diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index ad0681c..3f190cd 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -262,10 +262,9 @@ colormap = colormap * 100 colormap = colormap.astype(np.uint8) -def infer(model, image_tensor): +def infer(model, image_tensor, do_argmax=True): predictions = model.predict(tf.expand_dims((image_tensor), axis=0)) predictions = tf.squeeze(predictions) - predictions = tf.argmax(predictions, axis=2) return predictions @@ -296,6 +295,7 @@ def plot_samples_matplotlib(filepath, display_list): plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i])) else: plt.imshow(display_list[i]) + plt.colorbar() plt.savefig(filepath, dpi=200) @@ -303,8 +303,9 @@ def plot_predictions(filepath, input_items, colormap, model): i = 0 for input_pair in input_items: prediction_mask = infer(image_tensor=input_pair[0], model=model) + prediction_mask_argmax = tf.argmax(predictions, axis=2) # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2) - prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2) + prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2) # print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n[")) @@ -312,7 +313,8 @@ def plot_predictions(filepath, input_items, colormap, model): filepath.replace("$$", str(i)), [ # input_tensor, - input_pair[1], #label_colourmap + input_pair[1], #label_colourmap, + prediction_mask[:,:,1], prediction_colormap ] )