mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
ai: add missing gamma layer
This commit is contained in:
parent
51cf08a386
commit
e4edc68df5
1 changed files with 20 additions and 0 deletions
20
aimodel/src/lib/ai/components/LayerConvNeXtGamma.py
Normal file
20
aimodel/src/lib/ai/components/LayerConvNeXtGamma.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb
|
||||||
|
|
||||||
|
class LayerConvNeXtGamma(tf.keras.layers.Layer):
|
||||||
|
def __init__(self, const_val = 1e-6, dim = None, name=None, **kwargs):
|
||||||
|
super(LayerConvNeXtGamma, self).__init__(name=name)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.const = const_val * tf.ones((self.dim))
|
||||||
|
|
||||||
|
def call(self, inputs, **kwargs):
|
||||||
|
return tf.multiply(inputs, self.const)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = super(LayerConvNeXtGamma, self).get_config()
|
||||||
|
|
||||||
|
config.update({ "const": self.const, "dim": self.dim })
|
||||||
|
|
||||||
|
return config
|
Loading…
Reference in a new issue