diff --git a/aimodel/src/subcommands/train_predict.py b/aimodel/src/subcommands/train_predict.py index 44ffd51..3bce9fe 100644 --- a/aimodel/src/subcommands/train_predict.py +++ b/aimodel/src/subcommands/train_predict.py @@ -88,7 +88,7 @@ def run(args): sys.stderr.write(">>> Complete\n") def do_png(args, ai, dataset, model_code): - if not os.path.exists(args.output): + if not os.path.exists(os.path.dirname(args.output)): os.mkdir(os.path.dirname(args.output)) model_params = json.loads(readfile(args.params)) @@ -106,10 +106,10 @@ def do_png(args, ai, dataset, model_code): # [ width, height, softmax_probabilities ] → [ batch, width, height ] water_predict = tf.math.argmax(water_predict, axis=-1) # [ width, height ] - water_item = water[i_batch] + water_actual = water[i_batch] segmentation_plot( - water_item, water_predict, + water_actual, water_predict, model_code, args.output.replace("+d", str(i)) )