mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
contrastive: comment weights that aren't needed
This commit is contained in:
parent
33391eaf16
commit
55dc05e8ce
1 changed files with 13 additions and 13 deletions
|
@ -19,19 +19,19 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
||||||
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([
|
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([
|
||||||
math.log(1 / 0.07)
|
math.log(1 / 0.07)
|
||||||
]))
|
]))
|
||||||
self.weight_nce = tf.Variable(
|
# self.weight_nce = tf.Variable(
|
||||||
name="loss_nce",
|
# name="loss_nce",
|
||||||
shape=(batch_size, feature_dim),
|
# shape=(batch_size, feature_dim),
|
||||||
initial_value=tf.random.truncated_normal(
|
# initial_value=tf.random.truncated_normal(
|
||||||
(feature_dim),
|
# [feature_dim],
|
||||||
stddev=1.0 / math.sqrt(128)
|
# stddev=1.0 / math.sqrt(128)
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
self.weight_nce_bias = tf.Variable(
|
# self.weight_nce_bias = tf.Variable(
|
||||||
name="loss_nce_bias",
|
# name="loss_nce_bias",
|
||||||
shape=(feature_dim),
|
# shape=(feature_dim),
|
||||||
initial_value=tf.zeros((feature_dim))
|
# initial_value=tf.zeros((feature_dim))
|
||||||
)
|
# )
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = super(LayerCheeseMultipleOut, self).get_config()
|
config = super(LayerCheeseMultipleOut, self).get_config()
|
||||||
|
|
Loading…
Reference in a new issue