ai: fix arch_name plumbing

This commit is contained in:
Starbeamrainbowlabs 2022-09-02 17:57:07 +01:00
parent 3d44831080
commit f8ee0afca1
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -7,7 +7,7 @@ from .convnext import make_convnext
class LayerContrastiveEncoder(tf.keras.layers.Layer):
def __init__(self, input_width, input_height, channels, summary_file=None, feature_dim=2048, **kwargs):
def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs):
"""Creates a new contrastive learning encoder layer.
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
@ -30,13 +30,15 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
self.param_input_height = input_height
self.param_channels = channels
self.param_feature_dim = feature_dim
self.param_arch_name = arch_name
"""The main ConvNeXt model that forms the encoder.
"""
self.encoder = make_convnext(
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one
num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one
arch_name = self.param_arch_name
)
# """Small sequential stack of layers that control the size of the outputted feature dimension.
# """
@ -51,6 +53,7 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
config["input_height"] = self.param_input_height
config["input_channels"] = self.param_input_channels
config["feature_dim"] = self.param_feature_dim
config["arch_name"] = self.param_arch_name
return config
# def build(self, input_shape):