diff --git a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py index f607506..f36fa85 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py +++ b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py @@ -11,7 +11,7 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, logger.info(shape_water) # Shapes come from what rainfallwrangler sees them as, but we add an extra dimension when reading the .tfrecord file - rainfall_width, rainfall_height, rainfall_channels = shape_rainfall # shape = [width, height, channels] + rainfall_channels, rainfall_width, rainfall_height = shape_rainfall # shape = [channels, width, height] water_width, water_height = shape_water # shape = [width, height] water_channels = 1 # added in dataset → make_dataset → parse_item