From 12c77e128d570cc4ed2b7f0cb518efb6d109a82c Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 31 Aug 2022 17:41:51 +0100 Subject: [PATCH] handle feature_dim properly --- aimodel/src/lib/ai/model_rainfallwater_contrastive.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py index 97ea631..2bff097 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py +++ b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py @@ -5,7 +5,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut from .components.LossContrastive import LossContrastive -def model_rainfallwater_contrastive(shape_rainfall, shape_water): +def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200): rainfall_width, rainfall_height, rainfall_channels = shape_rainfall water_width, water_height, water_channels = shape_water @@ -20,12 +20,14 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water): rainfall = LayerContrastiveEncoder( input_width=rainfall_width, input_height=rainfall_height, - channels=rainfall_channels + channels=rainfall_channels, + feature_dim=feature_dim )(input_rainfall) water = LayerContrastiveEncoder( input_width=water_width, input_height=water_height, - channels=water_channels + channels=water_channels, + feature_dim=feature_dim )(input_water)