contrastive: rewrite the loss function.

The CLIP paper *does* kinda make sense, I think.
Starbeamrainbowlabs 2022-10-26 16:45:45 +01:00
parent fad1399c2d
commit 843cc8dc7b
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

@@ -17,14 +17,17 @@ class LossContrastive(tf.keras.losses.Loss):
 		# rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]])
 		# water = tf.reshape(water, [self.batch_size, water.shape[1]])
-		logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
+		# logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
+		logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.math.exp(self.weight_temperature)
 		# print("LOGITS", logits)
-		labels = tf.eye(self.batch_size, dtype=tf.int32)
-		loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
-		loss_water = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=1)
+		# labels = tf.eye(self.batch_size, dtype=tf.int32) # we *would* do this if we were using mean squared error...
+		labels = tf.range(self.batch_size, dtype=tf.int32) # each row is a different category we think
+		loss_rainfall = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=0)
+		loss_water = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=1)
+		# loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
+		# loss_water = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=1)
 		loss = (loss_rainfall + loss_water) / 2
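
For context, here is a minimal self-contained sketch of what the loss class looks like after this change, in the spirit of the pseudocode in the CLIP paper. Only the logits, labels, and the two sparse_categorical_crossentropy lines are taken from the diff above; the constructor signature and the unstacking of y_pred into the rainfall and water embedding matrices are assumptions for illustration, not copied from the repo.

import tensorflow as tf

class LossContrastive(tf.keras.losses.Loss):
	def __init__(self, batch_size, weight_temperature, **kwargs):
		super().__init__(**kwargs)
		self.batch_size = batch_size
		# learnable log-temperature that scales the similarity matrix, as in CLIP
		self.weight_temperature = weight_temperature

	def call(self, y_true, y_pred):
		# ASSUMPTION: the two embedding matrices arrive stacked as [batch, 2, dim]
		rainfall, water = tf.unstack(y_pred, axis=1)

		# [batch, batch] similarity matrix, scaled by the exponentiated temperature
		logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.math.exp(self.weight_temperature)

		# the i-th rainfall embedding pairs with the i-th water embedding,
		# so the "correct class" for row/column i is simply i
		labels = tf.range(self.batch_size, dtype=tf.int32)

		# classify rainfall->water along one axis and water->rainfall along the
		# other, then average: the symmetric cross-entropy from the CLIP paper
		loss_rainfall = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=0)
		loss_water = tf.keras.metrics.sparse_categorical_crossentropy(labels, logits, from_logits=True, axis=1)
		return (loss_rainfall + loss_water) / 2

The switch matters because tf.eye labels with binary crossentropy score every cell of the similarity matrix as an independent yes/no decision, while tf.range labels with sparse categorical crossentropy turn each row (and each column) into a softmax over the whole batch, which is the InfoNCE-style objective CLIP actually uses.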