mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
ai: fix 'nother crash' name ConvNeXt submodels
This commit is contained in:
parent
b2a320134e
commit
8bdded23eb
2 changed files with 10 additions and 5 deletions
|
@ -47,10 +47,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
|||
config["feature_dim"] = self.param_feature_dim
|
||||
return config
|
||||
|
||||
def build(self, input_shape):
|
||||
# print("LAYER:build input_shape", input_shape)
|
||||
super().build(input_shape=input_shape[0])
|
||||
self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
||||
# def build(self, input_shape):
|
||||
# # print("LAYER:build input_shape", input_shape)
|
||||
# super().build(input_shape=input_shape[0])
|
||||
# self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
||||
|
||||
def call(self, input_thing):
|
||||
result = self.encoder(input_thing)
|
||||
|
|
|
@ -21,6 +21,8 @@ depths_dims = dict(
|
|||
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
|
||||
)
|
||||
|
||||
next_model_number = 0
|
||||
|
||||
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
||||
"""Makes a ConvNeXt model.
|
||||
Returns a tf.keras.Model.
|
||||
|
@ -32,10 +34,13 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
|||
shape = input_shape
|
||||
)
|
||||
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
|
||||
return tf.keras.Model(
|
||||
result = tf.keras.Model(
|
||||
name=f"convnext{next_model_number}",
|
||||
inputs = layer_in,
|
||||
outputs = layer_out
|
||||
)
|
||||
next_model_number += 1
|
||||
return result
|
||||
|
||||
|
||||
def convnext(
|
||||
|
|
Loading…
Reference in a new issue