mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
ai: fix 'nother crash' name ConvNeXt submodels
This commit is contained in:
parent
b2a320134e
commit
8bdded23eb
2 changed files with 10 additions and 5 deletions
|
@ -47,10 +47,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||||
config["feature_dim"] = self.param_feature_dim
|
config["feature_dim"] = self.param_feature_dim
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def build(self, input_shape):
|
# def build(self, input_shape):
|
||||||
# print("LAYER:build input_shape", input_shape)
|
# # print("LAYER:build input_shape", input_shape)
|
||||||
super().build(input_shape=input_shape[0])
|
# super().build(input_shape=input_shape[0])
|
||||||
self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
# self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
||||||
|
|
||||||
def call(self, input_thing):
|
def call(self, input_thing):
|
||||||
result = self.encoder(input_thing)
|
result = self.encoder(input_thing)
|
||||||
|
|
|
@ -21,6 +21,8 @@ depths_dims = dict(
|
||||||
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
|
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
next_model_number = 0
|
||||||
|
|
||||||
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
||||||
"""Makes a ConvNeXt model.
|
"""Makes a ConvNeXt model.
|
||||||
Returns a tf.keras.Model.
|
Returns a tf.keras.Model.
|
||||||
|
@ -32,10 +34,13 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
||||||
shape = input_shape
|
shape = input_shape
|
||||||
)
|
)
|
||||||
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
|
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
|
||||||
return tf.keras.Model(
|
result = tf.keras.Model(
|
||||||
|
name=f"convnext{next_model_number}",
|
||||||
inputs = layer_in,
|
inputs = layer_in,
|
||||||
outputs = layer_out
|
outputs = layer_out
|
||||||
)
|
)
|
||||||
|
next_model_number += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def convnext(
|
def convnext(
|
||||||
|
|
Loading…
Reference in a new issue