diff --git a/aimodel/src/subcommands/train_mono_predict.py b/aimodel/src/subcommands/train_mono_predict.py index 709bef0..349dd1c 100644 --- a/aimodel/src/subcommands/train_mono_predict.py +++ b/aimodel/src/subcommands/train_mono_predict.py @@ -113,6 +113,7 @@ def do_png(args, ai, dataset, model_params): water_predict = tf.math.argmax(water_predict, axis=-1) # [ width, height, bins ] water_actual = tf.squeeze(water[i_batch]) + # [ width, height ] water_actual = tf.math.argmax(water_actual, axis=-1) segmentation_plot(