diff --git a/aimodel/src/subcommands/train_predict.py b/aimodel/src/subcommands/train_predict.py index 3624cf8..4fed4b3 100644 --- a/aimodel/src/subcommands/train_predict.py +++ b/aimodel/src/subcommands/train_predict.py @@ -40,6 +40,8 @@ def run(args): if (not hasattr(args, "params")) or args.params == None: args.params = find_paramsjson(args.checkpoint) + if args.params == None: + logger.error("Error: Failed to find params.json. Please ensure it's either in the same directory as the checkpoint or 1 level above") if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None: args.read_multiplier = 0 if (not hasattr(args, "records_per_file")) or args.records_per_file == None: