diff --git a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py index c74232e..0a1d4bd 100644 --- a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py +++ b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py @@ -7,7 +7,7 @@ from .convnext import make_convnext class LayerContrastiveEncoder(tf.keras.layers.Layer): - def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs): + def __init__(self, input_width, input_height, input_channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs): """Creates a new contrastive learning encoder layer. Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me. While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer. @@ -24,18 +24,18 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer): print(f"input_width: {input_width}") print(f"input_height: {input_height}") - print(f"channels: {channels}") + print(f"channels: {input_channels}") - self.param_input_width = input_width - self.param_input_height = input_height - self.param_channels = channels - self.param_feature_dim = feature_dim - self.param_arch_name = arch_name + self.param_input_width = input_width + self.param_input_height = input_height + self.param_input_channels = input_channels + self.param_feature_dim = feature_dim + self.param_arch_name = arch_name """The main ConvNeXt model that forms the encoder. """ self.encoder = make_convnext( - input_shape = (self.param_input_width, self.param_input_height, self.param_channels), + input_shape = (self.param_input_width, self.param_input_height, self.param_input_channels), classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one arch_name = self.param_arch_name diff --git a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py index 33a88ff..ab06f71 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py +++ b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py @@ -33,7 +33,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur rainfall = LayerContrastiveEncoder( input_width=rainfall_width, input_height=rainfall_height, - channels=rainfall_channels, + input_channels=rainfall_channels, feature_dim=feature_dim, summary_file=summary_file, arch_name="convnext_tiny", @@ -42,7 +42,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur water = LayerContrastiveEncoder( input_width=water_width, input_height=water_height, - channels=water_channels, + input_channels=water_channels, feature_dim=feature_dim, arch_name="convnext_xtiny", summary_file=summary_file