dlr: try plotting the label too

https://www.youtube.com/watch?v=03qwgVJbNas
This commit is contained in:
Starbeamrainbowlabs 2023-01-12 16:13:04 +00:00
parent be7dd91f88
commit 376eecc29f
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -258,24 +258,28 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
def plot_predictions(filepath, input_items, colormap, model): def plot_predictions(filepath, input_items, colormap, model):
for input_tensor in input_items: for input_pair in input_items:
prediction_mask = infer(image_tensor=input_tensor, model=model) prediction_mask = infer(image_tensor=input_pair[0], model=model)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20) # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
plot_samples_matplotlib( plot_samples_matplotlib(
filepath, filepath,
[ [
# input_tensor, # input_tensor,
input_items[1], #label_colourmap
prediction_colormap prediction_colormap
], ],
figsize=(18, 14) figsize=(18, 14)
) )
def get_inputs_from_batched(dataset, count): def get_from_batched(dataset, count):
result = [] result = []
for batched in dataset: for batched in dataset:
items = tf.unstack(batched[0], axis=0) items_input = tf.unstack(batched[0], axis=0)
for item in items: items_label = tf.unstack(batched[1], axis=0)
for item in zip(items_input, items_label):
result.append(item) result.append(item)
if len(result) >= count: if len(result) >= count:
return result return result
@ -283,13 +287,13 @@ def get_inputs_from_batched(dataset, count):
plot_predictions( plot_predictions(
os.path.join(DIR_OUTPUT, "predict_train.png"), os.path.join(DIR_OUTPUT, "predict_train.png"),
get_inputs_from_batched(dataset_train, 4), get_from_batched(dataset_train, 4),
colormap, colormap,
model=model model=model
) )
plot_predictions( plot_predictions(
os.path.join(DIR_OUTPUT, "predict_validate.png"), os.path.join(DIR_OUTPUT, "predict_validate.png"),
get_inputs_from_batched(dataset_validate, 4), get_from_batched(dataset_validate, 4),
colormap, colormap,
model=model model=model
) )