diff --git a/aimodel/src/lib/ai/components/cbam.py b/aimodel/src/lib/ai/components/cbam.py index 4dcee83..cfdedeb 100644 --- a/aimodel/src/lib/ai/components/cbam.py +++ b/aimodel/src/lib/ai/components/cbam.py @@ -33,15 +33,19 @@ class LayerCBAMAttentionChannel(tf.keras.layers.Layer): super(LayerCBAMAttentionSpatial, self).__init__(**kwargs) self.param_dim = dim + self.param_reduction_ratio = reduction_ratio self.mlp = tf.keras.Sequential([ + tf.keras.layers.Dense(self.param_dim), + tf.keras.layers.Dense(self.param_dim / self.param_reduction_ratio), tf.keras.layers.Dense(self.param_dim) ]) def get_config(self): config = super(LayerCBAMAttentionSpatial, self).get_config() config.update({ - "dim": self.param_dim + "dim": self.param_dim, + "reduction_ratio": self.param_reduction_ratio }) return config