diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index 3e18d1f..26581b8 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -56,6 +56,10 @@ def run(args): if not os.path.exists(args.checkpoint): raise Exception(f"Error: The specified filepath to the checkpoint to load ('{args.checkpoint}) does not exist.") + if args.records_per_file > 0: + dirpath_output=os.path.dirname(args.output) + if not os.path.exists(dirpath_output): + os.mkdir(dirpath_output) filepath_output = args.output if hasattr(args, "output") and args.output != None else "-"