diff --git a/aimodel/src/lib/ai/components/LossContrastive.py b/aimodel/src/lib/ai/components/LossContrastive.py index 697e517..59a2f8f 100644 --- a/aimodel/src/lib/ai/components/LossContrastive.py +++ b/aimodel/src/lib/ai/components/LossContrastive.py @@ -17,6 +17,10 @@ class LossContrastive(tf.keras.losses.Loss): # rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]]) # water = tf.reshape(water, [self.batch_size, water.shape[1]]) + # normalise features + rainfall = rainfall / tf.math.l2_normalize(rainfall, axis=1) + rainfall = rainfall / tf.math.l2_normalize(rainfall, axis=1) + # logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100) logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.math.exp(self.weight_temperature)