dlr: try plotting the label too

https://www.youtube.com/watch?v=03qwgVJbNas
This commit is contained in:
Starbeamrainbowlabs 2023-01-12 16:13:04 +00:00
parent be7dd91f88
commit 376eecc29f
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -258,24 +258,28 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
def plot_predictions(filepath, input_items, colormap, model):
for input_tensor in input_items:
prediction_mask = infer(image_tensor=input_tensor, model=model)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
for input_pair in input_items:
prediction_mask = infer(image_tensor=input_pair[0], model=model)
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
plot_samples_matplotlib(
filepath,
[
# input_tensor,
input_items[1], #label_colourmap
prediction_colormap
],
figsize=(18, 14)
)
def get_inputs_from_batched(dataset, count):
def get_from_batched(dataset, count):
result = []
for batched in dataset:
items = tf.unstack(batched[0], axis=0)
for item in items:
items_input = tf.unstack(batched[0], axis=0)
items_label = tf.unstack(batched[1], axis=0)
for item in zip(items_input, items_label):
result.append(item)
if len(result) >= count:
return result
@ -283,13 +287,13 @@ def get_inputs_from_batched(dataset, count):
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_train.png"),
get_inputs_from_batched(dataset_train, 4),
get_from_batched(dataset_train, 4),
colormap,
model=model
)
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_validate.png"),
get_inputs_from_batched(dataset_validate, 4),
get_from_batched(dataset_validate, 4),
colormap,
model=model
)