From 6423bf6702fe62763f197d70fce630ae9c232eda Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 12 Oct 2022 17:12:07 +0100 Subject: [PATCH] LayerConvNeXtGamma: avoid adding an EagerTensor to config Very weird how this is a problem when it wasn't before.. --- .../lib/ai/components/LayerConvNeXtGamma.py | 27 +++++++++---------- aimodel/src/lib/ai/helpers/make_callbacks.py | 2 +- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/aimodel/src/lib/ai/components/LayerConvNeXtGamma.py b/aimodel/src/lib/ai/components/LayerConvNeXtGamma.py index f0b8ec0..2396f86 100644 --- a/aimodel/src/lib/ai/components/LayerConvNeXtGamma.py +++ b/aimodel/src/lib/ai/components/LayerConvNeXtGamma.py @@ -3,18 +3,17 @@ import tensorflow as tf # Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb class LayerConvNeXtGamma(tf.keras.layers.Layer): - def __init__(self, const_val = 1e-6, dim = None, name=None, **kwargs): - super(LayerConvNeXtGamma, self).__init__(name=name) - - self.dim = dim - self.const = const_val * tf.ones((self.dim)) + def __init__(self, const_val = 1e-6, dim = None, name=None, **kwargs): + super(LayerConvNeXtGamma, self).__init__(name=name) + + self.dim = dim + self.const = const_val * tf.ones((self.dim)) - def call(self, inputs, **kwargs): - return tf.multiply(inputs, self.const) - - def get_config(self): - config = super(LayerConvNeXtGamma, self).get_config() - - config.update({ "const": self.const, "dim": self.dim }) - - return config + def call(self, inputs, **kwargs): + return tf.multiply(inputs, self.const) + + def get_config(self): + config = super(LayerConvNeXtGamma, self).get_config() + + config.update({ "const": self.const.numpy(), "dim": self.dim }) + return config diff --git a/aimodel/src/lib/ai/helpers/make_callbacks.py b/aimodel/src/lib/ai/helpers/make_callbacks.py index 98a6545..bfb112a 100644 --- a/aimodel/src/lib/ai/helpers/make_callbacks.py +++ b/aimodel/src/lib/ai/helpers/make_callbacks.py @@ -16,7 +16,7 @@ def make_callbacks(dirpath, model_predict): model_to_checkpoint=model_predict, filepath=os.path.join( dirpath_checkpoints, - "checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5" + "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5" ), monitor="loss" ),