contrastive: comment weights that aren't needed

This commit is contained in:
Starbeamrainbowlabs 2022-10-31 16:26:48 +00:00
parent 33391eaf16
commit 55dc05e8ce
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -19,19 +19,19 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([
math.log(1 / 0.07)
]))
self.weight_nce = tf.Variable(
name="loss_nce",
shape=(batch_size, feature_dim),
initial_value=tf.random.truncated_normal(
(feature_dim),
stddev=1.0 / math.sqrt(128)
)
)
self.weight_nce_bias = tf.Variable(
name="loss_nce_bias",
shape=(feature_dim),
initial_value=tf.zeros((feature_dim))
)
# self.weight_nce = tf.Variable(
# name="loss_nce",
# shape=(batch_size, feature_dim),
# initial_value=tf.random.truncated_normal(
# [feature_dim],
# stddev=1.0 / math.sqrt(128)
# )
# )
# self.weight_nce_bias = tf.Variable(
# name="loss_nce_bias",
# shape=(feature_dim),
# initial_value=tf.zeros((feature_dim))
# )
def get_config(self):
config = super(LayerCheeseMultipleOut, self).get_config()