From 8bdded23eb6466f401bf0fc31446a5796dcacd12 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 31 Aug 2022 18:57:27 +0100 Subject: [PATCH] ai: fix 'nother crash' name ConvNeXt submodels --- aimodel/src/lib/ai/components/LayerContrastiveEncoder.py | 8 ++++---- aimodel/src/lib/ai/components/convnext.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py index d3c34a6..2a2d94c 100644 --- a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py +++ b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py @@ -47,10 +47,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer): config["feature_dim"] = self.param_feature_dim return config - def build(self, input_shape): - # print("LAYER:build input_shape", input_shape) - super().build(input_shape=input_shape[0]) - self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ])) + # def build(self, input_shape): + # # print("LAYER:build input_shape", input_shape) + # super().build(input_shape=input_shape[0]) + # self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ])) def call(self, input_thing): result = self.encoder(input_thing) diff --git a/aimodel/src/lib/ai/components/convnext.py b/aimodel/src/lib/ai/components/convnext.py index f628498..7ab29e3 100644 --- a/aimodel/src/lib/ai/components/convnext.py +++ b/aimodel/src/lib/ai/components/convnext.py @@ -21,6 +21,8 @@ depths_dims = dict( convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])), ) +next_model_number = 0 + def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs): """Makes a ConvNeXt model. Returns a tf.keras.Model. @@ -32,10 +34,13 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs): shape = input_shape ) layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs) - return tf.keras.Model( + result = tf.keras.Model( + name=f"convnext{next_model_number}", inputs = layer_in, outputs = layer_out ) + next_model_number += 1 + return result def convnext(