ai: global? really?

This commit is contained in:
Starbeamrainbowlabs 2022-09-01 16:06:24 +01:00
parent 4952ead094
commit cfbbe8e8cf
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -21,7 +21,7 @@ depths_dims = dict(
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])), convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
) )
next_model_number = 0 __convnext_next_model_number = 0
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs): def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
"""Makes a ConvNeXt model. """Makes a ConvNeXt model.
@ -30,18 +30,18 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer. input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer.
arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny". arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny".
""" """
nonlocal next_model_number global __convnext_next_model_number
layer_in = tf.keras.layers.Input( layer_in = tf.keras.layers.Input(
shape = input_shape shape = input_shape
) )
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs) layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
result = tf.keras.Model( result = tf.keras.Model(
name=f"convnext{next_model_number}", name=f"convnext{__convnext_next_model_number}",
inputs = layer_in, inputs = layer_in,
outputs = layer_out outputs = layer_out
) )
next_model_number += 1 __convnext_next_model_number += 1
return result return result