diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 1be2807..59f58b7 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -258,24 +258,28 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)): def plot_predictions(filepath, input_items, colormap, model): - for input_tensor in input_items: - prediction_mask = infer(image_tensor=input_tensor, model=model) - prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20) + for input_pair in input_items: + prediction_mask = infer(image_tensor=input_pair[0], model=model) + # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2) + prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2) + plot_samples_matplotlib( filepath, [ # input_tensor, + input_items[1], #label_colourmap prediction_colormap ], figsize=(18, 14) ) -def get_inputs_from_batched(dataset, count): +def get_from_batched(dataset, count): result = [] for batched in dataset: - items = tf.unstack(batched[0], axis=0) - for item in items: + items_input = tf.unstack(batched[0], axis=0) + items_label = tf.unstack(batched[1], axis=0) + for item in zip(items_input, items_label): result.append(item) if len(result) >= count: return result @@ -283,13 +287,13 @@ def get_inputs_from_batched(dataset, count): plot_predictions( os.path.join(DIR_OUTPUT, "predict_train.png"), - get_inputs_from_batched(dataset_train, 4), + get_from_batched(dataset_train, 4), colormap, model=model ) plot_predictions( os.path.join(DIR_OUTPUT, "predict_validate.png"), - get_inputs_from_batched(dataset_validate, 4), + get_from_batched(dataset_validate, 4), colormap, model=model )