mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
dlr: when predicting, also display heatmap
...of positive predictions
This commit is contained in:
parent
5195fe6b62
commit
436ab78438
1 changed files with 6 additions and 4 deletions
|
@ -262,10 +262,9 @@ colormap = colormap * 100
|
|||
colormap = colormap.astype(np.uint8)
|
||||
|
||||
|
||||
def infer(model, image_tensor):
|
||||
def infer(model, image_tensor, do_argmax=True):
|
||||
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
|
||||
predictions = tf.squeeze(predictions)
|
||||
predictions = tf.argmax(predictions, axis=2)
|
||||
return predictions
|
||||
|
||||
|
||||
|
@ -296,6 +295,7 @@ def plot_samples_matplotlib(filepath, display_list):
|
|||
plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
|
||||
else:
|
||||
plt.imshow(display_list[i])
|
||||
plt.colorbar()
|
||||
plt.savefig(filepath, dpi=200)
|
||||
|
||||
|
||||
|
@ -303,8 +303,9 @@ def plot_predictions(filepath, input_items, colormap, model):
|
|||
i = 0
|
||||
for input_pair in input_items:
|
||||
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||
prediction_mask_argmax = tf.argmax(predictions, axis=2)
|
||||
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2)
|
||||
|
||||
# print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
|
||||
|
||||
|
@ -312,7 +313,8 @@ def plot_predictions(filepath, input_items, colormap, model):
|
|||
filepath.replace("$$", str(i)),
|
||||
[
|
||||
# input_tensor,
|
||||
input_pair[1], #label_colourmap
|
||||
input_pair[1], #label_colourmap,
|
||||
prediction_mask[:,:,1],
|
||||
prediction_colormap
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue