mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
dlr: when predicting, also display heatmap
...of positive predictions
This commit is contained in:
parent
5195fe6b62
commit
436ab78438
1 changed files with 6 additions and 4 deletions
|
@ -262,10 +262,9 @@ colormap = colormap * 100
|
||||||
colormap = colormap.astype(np.uint8)
|
colormap = colormap.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
def infer(model, image_tensor):
|
def infer(model, image_tensor, do_argmax=True):
|
||||||
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
|
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
|
||||||
predictions = tf.squeeze(predictions)
|
predictions = tf.squeeze(predictions)
|
||||||
predictions = tf.argmax(predictions, axis=2)
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
@ -296,6 +295,7 @@ def plot_samples_matplotlib(filepath, display_list):
|
||||||
plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
|
plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
|
||||||
else:
|
else:
|
||||||
plt.imshow(display_list[i])
|
plt.imshow(display_list[i])
|
||||||
|
plt.colorbar()
|
||||||
plt.savefig(filepath, dpi=200)
|
plt.savefig(filepath, dpi=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ -303,8 +303,9 @@ def plot_predictions(filepath, input_items, colormap, model):
|
||||||
i = 0
|
i = 0
|
||||||
for input_pair in input_items:
|
for input_pair in input_items:
|
||||||
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||||
|
prediction_mask_argmax = tf.argmax(predictions, axis=2)
|
||||||
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2)
|
||||||
|
|
||||||
# print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
|
# print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
|
||||||
|
|
||||||
|
@ -312,7 +313,8 @@ def plot_predictions(filepath, input_items, colormap, model):
|
||||||
filepath.replace("$$", str(i)),
|
filepath.replace("$$", str(i)),
|
||||||
[
|
[
|
||||||
# input_tensor,
|
# input_tensor,
|
||||||
input_pair[1], #label_colourmap
|
input_pair[1], #label_colourmap,
|
||||||
|
prediction_mask[:,:,1],
|
||||||
prediction_colormap
|
prediction_colormap
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue