This commit is contained in:
Starbeamrainbowlabs 2023-02-03 16:01:54 +00:00
parent 1a8f10339a
commit 8446a842d1
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -7,7 +7,7 @@ class LayerConvNeXtGamma(tf.keras.layers.Layer):
super(LayerConvNeXtGamma, self).__init__(name=name)
self.dim = dim
self.const = const_val * tf.ones((self.dim), dtype=tf.float32 if tf.mixed_precision.global_policy().name == "float32" else tf.float16)
self.const = const_val * tf.ones((self.dim), dtype=tf.float32 if tf.keras.mixed_precision.global_policy().name == "float32" else tf.float16)
def call(self, inputs, **kwargs):
return tf.multiply(inputs, self.const)