mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
dlr: try plotting the label too
https://www.youtube.com/watch?v=03qwgVJbNas
This commit is contained in:
parent
be7dd91f88
commit
376eecc29f
1 changed files with 12 additions and 8 deletions
|
@ -258,24 +258,28 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
|||
|
||||
|
||||
def plot_predictions(filepath, input_items, colormap, model):
|
||||
for input_tensor in input_items:
|
||||
prediction_mask = infer(image_tensor=input_tensor, model=model)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
|
||||
for input_pair in input_items:
|
||||
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
||||
|
||||
|
||||
plot_samples_matplotlib(
|
||||
filepath,
|
||||
[
|
||||
# input_tensor,
|
||||
input_items[1], #label_colourmap
|
||||
prediction_colormap
|
||||
],
|
||||
figsize=(18, 14)
|
||||
)
|
||||
|
||||
def get_inputs_from_batched(dataset, count):
|
||||
def get_from_batched(dataset, count):
|
||||
result = []
|
||||
for batched in dataset:
|
||||
items = tf.unstack(batched[0], axis=0)
|
||||
for item in items:
|
||||
items_input = tf.unstack(batched[0], axis=0)
|
||||
items_label = tf.unstack(batched[1], axis=0)
|
||||
for item in zip(items_input, items_label):
|
||||
result.append(item)
|
||||
if len(result) >= count:
|
||||
return result
|
||||
|
@ -283,13 +287,13 @@ def get_inputs_from_batched(dataset, count):
|
|||
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||
get_inputs_from_batched(dataset_train, 4),
|
||||
get_from_batched(dataset_train, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||
get_inputs_from_batched(dataset_validate, 4),
|
||||
get_from_batched(dataset_validate, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue