From 36859746ff28cbd821303a502191db177e2dc2b7 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 6 Jan 2023 17:13:35 +0000 Subject: [PATCH] dlr: fix crash --- aimodel/src/deeplabv3_plus_test_rainfall.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 86928f0..23aeef2 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -249,10 +249,10 @@ def plot_predictions(filepath, input_items, colormap, model): figsize=(18, 14) ) -def get_items_from_batched(dataset, count): +def get_inputs_from_batched(dataset, count): result = [] for batched in dataset: - items = tf.unstack(batched, axis=0) + items = tf.unstack(batched[0], axis=0) for item in items: result.append(item) if len(result) >= count: @@ -261,13 +261,13 @@ def get_items_from_batched(dataset, count): plot_predictions( os.path.join(DIR_OUTPUT, "predict_train.png"), - get_items_from_batched(dataset_train, 4), + get_inputs_from_batched(dataset_train, 4), colormap, model=model ) plot_predictions( os.path.join(DIR_OUTPUT, "predict_validate.png"), - get_items_from_batched(dataset_validate, 4), + get_inputs_from_batched(dataset_validate, 4), colormap, model=model )