From 843cc8dc7bdc54c4435ce56a8154c627e00146ae Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 26 Oct 2022 16:45:45 +0100 Subject: [PATCH] contrastive: rewrite the loss function. The CLIP paper *does* kinda make sense I think --- aimodel/src/lib/ai/components/LossContrastive.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/aimodel/src/lib/ai/components/LossContrastive.py b/aimodel/src/lib/ai/components/LossContrastive.py index 43e6e74..697e517 100644 --- a/aimodel/src/lib/ai/components/LossContrastive.py +++ b/aimodel/src/lib/ai/components/LossContrastive.py @@ -17,14 +17,17 @@ class LossContrastive(tf.keras.losses.Loss): # rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]]) # water = tf.reshape(water, [self.batch_size, water.shape[1]]) - - logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100) + # logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100) + logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.math.exp(self.weight_temperature) # print("LOGITS", logits) - labels = tf.eye(self.batch_size, dtype=tf.int32) - loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0) - loss_water = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=1) + # labels = tf.eye(self.batch_size, dtype=tf.int32) # we *would* do this if we were using mean squared error... + labels = tf.range(self.batch_size, dtype=tf.int32) # each row is a different category we think + loss_rainfall = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=0) + loss_water = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=1) + # loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0) + # loss_water = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=1) loss = (loss_rainfall + loss_water) / 2