diff --git a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py index 6c780e7..524eff4 100644 --- a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py +++ b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py @@ -7,7 +7,7 @@ from .convnext import make_convnext class LayerContrastiveEncoder(tf.keras.layers.Layer): - def __init__(self, input_width, input_height, channels, summary_file=None, feature_dim=2048, **kwargs): + def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs): """Creates a new contrastive learning encoder layer. Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me. While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer. @@ -30,13 +30,15 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer): self.param_input_height = input_height self.param_channels = channels self.param_feature_dim = feature_dim + self.param_arch_name = arch_name """The main ConvNeXt model that forms the encoder. """ self.encoder = make_convnext( input_shape = (self.param_input_width, self.param_input_height, self.param_channels), classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder - num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one + num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one + arch_name = self.param_arch_name ) # """Small sequential stack of layers that control the size of the outputted feature dimension. # """ @@ -51,6 +53,7 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer): config["input_height"] = self.param_input_height config["input_channels"] = self.param_input_channels config["feature_dim"] = self.param_feature_dim + config["arch_name"] = self.param_arch_name return config # def build(self, input_shape):