mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
ai: fix arch_name plumbing
This commit is contained in:
parent
3d44831080
commit
f8ee0afca1
1 changed files with 5 additions and 2 deletions
|
@ -7,7 +7,7 @@ from .convnext import make_convnext
|
|||
|
||||
class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, input_width, input_height, channels, summary_file=None, feature_dim=2048, **kwargs):
|
||||
def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs):
|
||||
"""Creates a new contrastive learning encoder layer.
|
||||
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
|
||||
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
|
||||
|
@ -30,13 +30,15 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
|||
self.param_input_height = input_height
|
||||
self.param_channels = channels
|
||||
self.param_feature_dim = feature_dim
|
||||
self.param_arch_name = arch_name
|
||||
|
||||
"""The main ConvNeXt model that forms the encoder.
|
||||
"""
|
||||
self.encoder = make_convnext(
|
||||
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
|
||||
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
|
||||
num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one
|
||||
num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one
|
||||
arch_name = self.param_arch_name
|
||||
)
|
||||
# """Small sequential stack of layers that control the size of the outputted feature dimension.
|
||||
# """
|
||||
|
@ -51,6 +53,7 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
|||
config["input_height"] = self.param_input_height
|
||||
config["input_channels"] = self.param_input_channels
|
||||
config["feature_dim"] = self.param_feature_dim
|
||||
config["arch_name"] = self.param_arch_name
|
||||
return config
|
||||
|
||||
# def build(self, input_shape):
|
||||
|
|
Loading…
Reference in a new issue