ai: global? really?

This commit is contained in:
Starbeamrainbowlabs 2022-09-01 16:06:24 +01:00
parent 4952ead094
commit cfbbe8e8cf
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -21,7 +21,7 @@ depths_dims = dict(
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
)
next_model_number = 0
__convnext_next_model_number = 0
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
"""Makes a ConvNeXt model.
@ -30,18 +30,18 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer.
arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny".
"""
nonlocal next_model_number
global __convnext_next_model_number
layer_in = tf.keras.layers.Input(
shape = input_shape
)
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
result = tf.keras.Model(
name=f"convnext{next_model_number}",
name=f"convnext{__convnext_next_model_number}",
inputs = layer_in,
outputs = layer_out
)
next_model_number += 1
__convnext_next_model_number += 1
return result