mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
dlr: try plotting the label too
https://www.youtube.com/watch?v=03qwgVJbNas
This commit is contained in:
parent
be7dd91f88
commit
376eecc29f
1 changed files with 12 additions and 8 deletions
|
@ -258,24 +258,28 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
||||||
|
|
||||||
|
|
||||||
def plot_predictions(filepath, input_items, colormap, model):
|
def plot_predictions(filepath, input_items, colormap, model):
|
||||||
for input_tensor in input_items:
|
for input_pair in input_items:
|
||||||
prediction_mask = infer(image_tensor=input_tensor, model=model)
|
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
|
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||||
|
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
||||||
|
|
||||||
|
|
||||||
plot_samples_matplotlib(
|
plot_samples_matplotlib(
|
||||||
filepath,
|
filepath,
|
||||||
[
|
[
|
||||||
# input_tensor,
|
# input_tensor,
|
||||||
|
input_items[1], #label_colourmap
|
||||||
prediction_colormap
|
prediction_colormap
|
||||||
],
|
],
|
||||||
figsize=(18, 14)
|
figsize=(18, 14)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_inputs_from_batched(dataset, count):
|
def get_from_batched(dataset, count):
|
||||||
result = []
|
result = []
|
||||||
for batched in dataset:
|
for batched in dataset:
|
||||||
items = tf.unstack(batched[0], axis=0)
|
items_input = tf.unstack(batched[0], axis=0)
|
||||||
for item in items:
|
items_label = tf.unstack(batched[1], axis=0)
|
||||||
|
for item in zip(items_input, items_label):
|
||||||
result.append(item)
|
result.append(item)
|
||||||
if len(result) >= count:
|
if len(result) >= count:
|
||||||
return result
|
return result
|
||||||
|
@ -283,13 +287,13 @@ def get_inputs_from_batched(dataset, count):
|
||||||
|
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||||
get_inputs_from_batched(dataset_train, 4),
|
get_from_batched(dataset_train, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||||
get_inputs_from_batched(dataset_validate, 4),
|
get_from_batched(dataset_validate, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue