From c0c6e81c01ad2f27bea9f8babde8bbdcd00a8fd2 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 12 Jan 2023 18:12:50 +0000 Subject: [PATCH] dlr predict: allow for multiple outputs --- aimodel/src/deeplabv3_plus_test_rainfall.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index c3080c5..62e2b43 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -256,14 +256,16 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)): def plot_predictions(filepath, input_items, colormap, model): + i = 0 for input_pair in input_items: prediction_mask = infer(image_tensor=input_pair[0], model=model) # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2) prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2) + print("DEBUG:plot_predictions INFER", prediction_mask) plot_samples_matplotlib( - filepath, + filepath.replace("$$", i), [ # input_tensor, input_pair[1], #label_colourmap @@ -271,6 +273,7 @@ def plot_predictions(filepath, input_items, colormap, model): ], figsize=(18, 14) ) + i += 1 def get_from_batched(dataset, count): result = [] @@ -284,13 +287,13 @@ def get_from_batched(dataset, count): plot_predictions( - os.path.join(DIR_OUTPUT, "predict_train.png"), + os.path.join(DIR_OUTPUT, "predict_train_$$.png"), get_from_batched(dataset_train, 4), colormap, model=model ) plot_predictions( - os.path.join(DIR_OUTPUT, "predict_validate.png"), + os.path.join(DIR_OUTPUT, "predict_validate_$$.png"), get_from_batched(dataset_validate, 4), colormap, model=model