dlr: when predicting, also display heatmap

...of positive predictions
This commit is contained in:
Starbeamrainbowlabs 2023-03-09 18:54:28 +00:00
parent 5195fe6b62
commit 436ab78438
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -262,10 +262,9 @@ colormap = colormap * 100
colormap = colormap.astype(np.uint8) colormap = colormap.astype(np.uint8)
def infer(model, image_tensor): def infer(model, image_tensor, do_argmax=True):
predictions = model.predict(tf.expand_dims((image_tensor), axis=0)) predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
predictions = tf.squeeze(predictions) predictions = tf.squeeze(predictions)
predictions = tf.argmax(predictions, axis=2)
return predictions return predictions
@ -296,6 +295,7 @@ def plot_samples_matplotlib(filepath, display_list):
plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i])) plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
else: else:
plt.imshow(display_list[i]) plt.imshow(display_list[i])
plt.colorbar()
plt.savefig(filepath, dpi=200) plt.savefig(filepath, dpi=200)
@ -303,8 +303,9 @@ def plot_predictions(filepath, input_items, colormap, model):
i = 0 i = 0
for input_pair in input_items: for input_pair in input_items:
prediction_mask = infer(image_tensor=input_pair[0], model=model) prediction_mask = infer(image_tensor=input_pair[0], model=model)
prediction_mask_argmax = tf.argmax(predictions, axis=2)
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2) # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2) prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2)
# print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n[")) # print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
@ -312,7 +313,8 @@ def plot_predictions(filepath, input_items, colormap, model):
filepath.replace("$$", str(i)), filepath.replace("$$", str(i)),
[ [
# input_tensor, # input_tensor,
input_pair[1], #label_colourmap input_pair[1], #label_colourmap,
prediction_mask[:,:,1],
prediction_colormap prediction_colormap
] ]
) )