diff --git a/aimodel/src/subcommands/train_predict.py b/aimodel/src/subcommands/train_predict.py index f1ef0da..ac84492 100644 --- a/aimodel/src/subcommands/train_predict.py +++ b/aimodel/src/subcommands/train_predict.py @@ -87,13 +87,13 @@ def run(args): logger.info(f"Records per file: {args.records_per_file}") if output_mode == MODE_JSONL: - do_jsonl(args, ai, dataset, args.model_code, model_params) + do_jsonl(args, ai, dataset, model_params) else: - do_png(args, ai, dataset, args.model_code, model_params) + do_png(args, ai, dataset, model_params) sys.stderr.write(">>> Complete\n") -def do_png(args, ai, dataset, model_code, model_params): +def do_png(args, ai, dataset, model_params): if not os.path.exists(os.path.dirname(args.output)): os.mkdir(os.path.dirname(args.output)) @@ -114,7 +114,7 @@ def do_png(args, ai, dataset, model_code, model_params): segmentation_plot( water_actual, water_predict, - model_code, + args.model_code, args.output.replace("+d", str(i)) )