ai Bugfix LayerContrastiveEncoder: channels → input_channels

for consistency
This commit is contained in:
Starbeamrainbowlabs 2022-09-05 23:53:16 +01:00
parent ead8009425
commit 3e13ad12c8
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 10 additions and 10 deletions

View file

@ -7,7 +7,7 @@ from .convnext import make_convnext
class LayerContrastiveEncoder(tf.keras.layers.Layer):
def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs):
def __init__(self, input_width, input_height, input_channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs):
"""Creates a new contrastive learning encoder layer.
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
@ -24,18 +24,18 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
print(f"input_width: {input_width}")
print(f"input_height: {input_height}")
print(f"channels: {channels}")
print(f"channels: {input_channels}")
self.param_input_width = input_width
self.param_input_height = input_height
self.param_channels = channels
self.param_feature_dim = feature_dim
self.param_arch_name = arch_name
self.param_input_width = input_width
self.param_input_height = input_height
self.param_input_channels = input_channels
self.param_feature_dim = feature_dim
self.param_arch_name = arch_name
"""The main ConvNeXt model that forms the encoder.
"""
self.encoder = make_convnext(
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
input_shape = (self.param_input_width, self.param_input_height, self.param_input_channels),
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one
arch_name = self.param_arch_name

View file

@ -33,7 +33,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
rainfall = LayerContrastiveEncoder(
input_width=rainfall_width,
input_height=rainfall_height,
channels=rainfall_channels,
input_channels=rainfall_channels,
feature_dim=feature_dim,
summary_file=summary_file,
arch_name="convnext_tiny",
@ -42,7 +42,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
water = LayerContrastiveEncoder(
input_width=water_width,
input_height=water_height,
channels=water_channels,
input_channels=water_channels,
feature_dim=feature_dim,
arch_name="convnext_xtiny",
summary_file=summary_file