diff --git a/aimodel/src/lib/ai/RainfallWaterContraster.py b/aimodel/src/lib/ai/RainfallWaterContraster.py index c871949..b452799 100644 --- a/aimodel/src/lib/ai/RainfallWaterContraster.py +++ b/aimodel/src/lib/ai/RainfallWaterContraster.py @@ -57,7 +57,7 @@ class RainfallWaterContraster(object): def make_model(self): - model = model_rainfallwater_contrastive(**self.kwargs) + model = model_rainfallwater_contrastive(batch_size=self.batch_size, **self.kwargs) return model diff --git a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py index ddcc83b..23b7a75 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py +++ b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py @@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut from .components.LossContrastive import LossContrastive -def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=2048): +def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048): logger.info(shape_rainfall) logger.info(shape_water) @@ -48,5 +48,5 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=204 model.compile( optimizer="Adam", - loss=LossContrastive(weight_temperature=weight_temperature) + loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature) ) \ No newline at end of file