diff --git a/aimodel/src/subcommands/train_mono_predict.py b/aimodel/src/subcommands/train_mono_predict.py index 349dd1c..4e3c5c1 100644 --- a/aimodel/src/subcommands/train_mono_predict.py +++ b/aimodel/src/subcommands/train_mono_predict.py @@ -128,7 +128,7 @@ def do_png(args, ai, dataset, model_params): if i % 100 == 0: sys.stderr.write(f"Processed {i} items\r") -def do_jsonl(args, ai, dataset, model_params): +def do_jsonl(args, ai, dataset, model_params, do_argmax=False): write_mode = "wt" if args.output.endswith(".gz") else "w" handle = sys.stdout @@ -156,7 +156,8 @@ def do_jsonl(args, ai, dataset, model_params): i_batch = 0 for water_predict in water_predict_batch: # [ width, height, softmax_probabilities ] → [ batch, width, height ] - # water_predict = tf.math.argmax(water_predict, axis=-1) + if do_argmax: + water_predict = tf.math.argmax(water_predict, axis=-1) # [ width, height ] water_actual = tf.squeeze(water_actual_batch[i_batch])