handle feature_dim properly

This commit is contained in:
Starbeamrainbowlabs 2022-08-31 17:41:51 +01:00
parent c52a9f961c
commit 12c77e128d
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -5,7 +5,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
from .components.LossContrastive import LossContrastive
def model_rainfallwater_contrastive(shape_rainfall, shape_water):
def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200):
rainfall_width, rainfall_height, rainfall_channels = shape_rainfall
water_width, water_height, water_channels = shape_water
@ -20,12 +20,14 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water):
rainfall = LayerContrastiveEncoder(
input_width=rainfall_width,
input_height=rainfall_height,
channels=rainfall_channels
channels=rainfall_channels,
feature_dim=feature_dim
)(input_rainfall)
water = LayerContrastiveEncoder(
input_width=water_width,
input_height=water_height,
channels=water_channels
channels=water_channels,
feature_dim=feature_dim
)(input_water)