mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
ai: global? really?
This commit is contained in:
parent
4952ead094
commit
cfbbe8e8cf
1 changed files with 4 additions and 4 deletions
|
@ -21,7 +21,7 @@ depths_dims = dict(
|
|||
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
|
||||
)
|
||||
|
||||
next_model_number = 0
|
||||
__convnext_next_model_number = 0
|
||||
|
||||
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
||||
"""Makes a ConvNeXt model.
|
||||
|
@ -30,18 +30,18 @@ def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
|||
input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer.
|
||||
arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny".
|
||||
"""
|
||||
nonlocal next_model_number
|
||||
global __convnext_next_model_number
|
||||
|
||||
layer_in = tf.keras.layers.Input(
|
||||
shape = input_shape
|
||||
)
|
||||
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
|
||||
result = tf.keras.Model(
|
||||
name=f"convnext{next_model_number}",
|
||||
name=f"convnext{__convnext_next_model_number}",
|
||||
inputs = layer_in,
|
||||
outputs = layer_out
|
||||
)
|
||||
next_model_number += 1
|
||||
__convnext_next_model_number += 1
|
||||
return result
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue