contrastive: comment weights that aren't needed

This commit is contained in:
Starbeamrainbowlabs 2022-10-31 16:26:48 +00:00
parent 33391eaf16
commit 55dc05e8ce
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -19,19 +19,19 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([ self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([
math.log(1 / 0.07) math.log(1 / 0.07)
])) ]))
self.weight_nce = tf.Variable( # self.weight_nce = tf.Variable(
name="loss_nce", # name="loss_nce",
shape=(batch_size, feature_dim), # shape=(batch_size, feature_dim),
initial_value=tf.random.truncated_normal( # initial_value=tf.random.truncated_normal(
(feature_dim), # [feature_dim],
stddev=1.0 / math.sqrt(128) # stddev=1.0 / math.sqrt(128)
) # )
) # )
self.weight_nce_bias = tf.Variable( # self.weight_nce_bias = tf.Variable(
name="loss_nce_bias", # name="loss_nce_bias",
shape=(feature_dim), # shape=(feature_dim),
initial_value=tf.zeros((feature_dim)) # initial_value=tf.zeros((feature_dim))
) # )
def get_config(self): def get_config(self):
config = super(LayerCheeseMultipleOut, self).get_config() config = super(LayerCheeseMultipleOut, self).get_config()