ai: fix 'nother crash' name ConvNeXt submodels

This commit is contained in:
Starbeamrainbowlabs 2022-08-31 18:57:27 +01:00
parent b2a320134e
commit 8bdded23eb
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 10 additions and 5 deletions

View file

@ -47,10 +47,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
config["feature_dim"] = self.param_feature_dim config["feature_dim"] = self.param_feature_dim
return config return config
def build(self, input_shape): # def build(self, input_shape):
# print("LAYER:build input_shape", input_shape) # # print("LAYER:build input_shape", input_shape)
super().build(input_shape=input_shape[0]) # super().build(input_shape=input_shape[0])
self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ])) # self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
def call(self, input_thing): def call(self, input_thing):
result = self.encoder(input_thing) result = self.encoder(input_thing)

View file

@ -21,6 +21,8 @@ depths_dims = dict(
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])), convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
) )
next_model_number = 0
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs): def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
"""Makes a ConvNeXt model. """Makes a ConvNeXt model.
Returns a tf.keras.Model. Returns a tf.keras.Model.
@ -32,10 +34,13 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
shape = input_shape shape = input_shape
) )
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs) layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
return tf.keras.Model( result = tf.keras.Model(
name=f"convnext{next_model_number}",
inputs = layer_in, inputs = layer_in,
outputs = layer_out outputs = layer_out
) )
next_model_number += 1
return result
def convnext( def convnext(