diff --git a/aimodel/src/subcommands/train_mono_predict.py b/aimodel/src/subcommands/train_mono_predict.py index e413f47..709bef0 100644 --- a/aimodel/src/subcommands/train_mono_predict.py +++ b/aimodel/src/subcommands/train_mono_predict.py @@ -111,8 +111,9 @@ def do_png(args, ai, dataset, model_params): for water_predict in water_predict_batch: # [ width, height, softmax_probabilities ] → [ batch, width, height ] water_predict = tf.math.argmax(water_predict, axis=-1) - # [ width, height ] + # [ width, height, bins ] water_actual = tf.squeeze(water[i_batch]) + water_actual = tf.math.argmax(water_actual, axis=-1) segmentation_plot( water_actual, water_predict,