diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 86928f0..23aeef2 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -249,10 +249,10 @@ def plot_predictions(filepath, input_items, colormap, model): figsize=(18, 14) ) -def get_items_from_batched(dataset, count): +def get_inputs_from_batched(dataset, count): result = [] for batched in dataset: - items = tf.unstack(batched, axis=0) + items = tf.unstack(batched[0], axis=0) for item in items: result.append(item) if len(result) >= count: @@ -261,13 +261,13 @@ def get_items_from_batched(dataset, count): plot_predictions( os.path.join(DIR_OUTPUT, "predict_train.png"), - get_items_from_batched(dataset_train, 4), + get_inputs_from_batched(dataset_train, 4), colormap, model=model ) plot_predictions( os.path.join(DIR_OUTPUT, "predict_validate.png"), - get_items_from_batched(dataset_validate, 4), + get_inputs_from_batched(dataset_validate, 4), colormap, model=model )