diff --git a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py index 87c2e43..21f976d 100644 --- a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py +++ b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py @@ -16,7 +16,9 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer): self.param_batch_size = batch_size self.param_feature_dim = feature_dim - self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07])) + self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([ + math.log(1 / 0.07) + ])) self.weight_nce = tf.Variable( name="loss_nce", shape=(batch_size, feature_dim),