dlr predict: allow for multiple outputs

This commit is contained in:
Starbeamrainbowlabs 2023-01-12 18:12:50 +00:00
parent 864dfa802d
commit c0c6e81c01
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -256,14 +256,16 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
def plot_predictions(filepath, input_items, colormap, model):
i = 0
for input_pair in input_items:
prediction_mask = infer(image_tensor=input_pair[0], model=model)
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
print("DEBUG:plot_predictions INFER", prediction_mask)
plot_samples_matplotlib(
filepath,
filepath.replace("$$", i),
[
# input_tensor,
input_pair[1], #label_colourmap
@ -271,6 +273,7 @@ def plot_predictions(filepath, input_items, colormap, model):
],
figsize=(18, 14)
)
i += 1
def get_from_batched(dataset, count):
result = []
@ -284,13 +287,13 @@ def get_from_batched(dataset, count):
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_train.png"),
os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
get_from_batched(dataset_train, 4),
colormap,
model=model
)
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_validate.png"),
os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
get_from_batched(dataset_validate, 4),
colormap,
model=model