mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
CBAM: unsure if it's 1 ro 3 dense ayers in the shared mlp
This commit is contained in:
parent
62f6a993bb
commit
cad82cd1bc
1 changed files with 5 additions and 1 deletions
|
@ -33,15 +33,19 @@ class LayerCBAMAttentionChannel(tf.keras.layers.Layer):
|
||||||
super(LayerCBAMAttentionSpatial, self).__init__(**kwargs)
|
super(LayerCBAMAttentionSpatial, self).__init__(**kwargs)
|
||||||
|
|
||||||
self.param_dim = dim
|
self.param_dim = dim
|
||||||
|
self.param_reduction_ratio = reduction_ratio
|
||||||
|
|
||||||
self.mlp = tf.keras.Sequential([
|
self.mlp = tf.keras.Sequential([
|
||||||
|
tf.keras.layers.Dense(self.param_dim),
|
||||||
|
tf.keras.layers.Dense(self.param_dim / self.param_reduction_ratio),
|
||||||
tf.keras.layers.Dense(self.param_dim)
|
tf.keras.layers.Dense(self.param_dim)
|
||||||
])
|
])
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = super(LayerCBAMAttentionSpatial, self).get_config()
|
config = super(LayerCBAMAttentionSpatial, self).get_config()
|
||||||
config.update({
|
config.update({
|
||||||
"dim": self.param_dim
|
"dim": self.param_dim,
|
||||||
|
"reduction_ratio": self.param_reduction_ratio
|
||||||
})
|
})
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue