mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
ai: fix arch_name plumbing
This commit is contained in:
parent
3d44831080
commit
f8ee0afca1
1 changed files with 5 additions and 2 deletions
|
@ -7,7 +7,7 @@ from .convnext import make_convnext
|
||||||
|
|
||||||
class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self, input_width, input_height, channels, summary_file=None, feature_dim=2048, **kwargs):
|
def __init__(self, input_width, input_height, channels, arch_name="convnext_tiny", summary_file=None, feature_dim=2048, **kwargs):
|
||||||
"""Creates a new contrastive learning encoder layer.
|
"""Creates a new contrastive learning encoder layer.
|
||||||
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
|
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
|
||||||
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
|
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
|
||||||
|
@ -30,13 +30,15 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||||
self.param_input_height = input_height
|
self.param_input_height = input_height
|
||||||
self.param_channels = channels
|
self.param_channels = channels
|
||||||
self.param_feature_dim = feature_dim
|
self.param_feature_dim = feature_dim
|
||||||
|
self.param_arch_name = arch_name
|
||||||
|
|
||||||
"""The main ConvNeXt model that forms the encoder.
|
"""The main ConvNeXt model that forms the encoder.
|
||||||
"""
|
"""
|
||||||
self.encoder = make_convnext(
|
self.encoder = make_convnext(
|
||||||
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
|
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
|
||||||
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
|
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
|
||||||
num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one
|
num_classes = self.param_feature_dim, # size of the feature dimension, see the line above this one
|
||||||
|
arch_name = self.param_arch_name
|
||||||
)
|
)
|
||||||
# """Small sequential stack of layers that control the size of the outputted feature dimension.
|
# """Small sequential stack of layers that control the size of the outputted feature dimension.
|
||||||
# """
|
# """
|
||||||
|
@ -51,6 +53,7 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||||
config["input_height"] = self.param_input_height
|
config["input_height"] = self.param_input_height
|
||||||
config["input_channels"] = self.param_input_channels
|
config["input_channels"] = self.param_input_channels
|
||||||
config["feature_dim"] = self.param_feature_dim
|
config["feature_dim"] = self.param_feature_dim
|
||||||
|
config["arch_name"] = self.param_arch_name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
# def build(self, input_shape):
|
# def build(self, input_shape):
|
||||||
|
|
Loading…
Reference in a new issue