diff --git a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py index 21f976d..3f9481a 100644 --- a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py +++ b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py @@ -19,19 +19,19 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer): self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([ math.log(1 / 0.07) ])) - self.weight_nce = tf.Variable( - name="loss_nce", - shape=(batch_size, feature_dim), - initial_value=tf.random.truncated_normal( - (feature_dim), - stddev=1.0 / math.sqrt(128) - ) - ) - self.weight_nce_bias = tf.Variable( - name="loss_nce_bias", - shape=(feature_dim), - initial_value=tf.zeros((feature_dim)) - ) + # self.weight_nce = tf.Variable( + # name="loss_nce", + # shape=(batch_size, feature_dim), + # initial_value=tf.random.truncated_normal( + # [feature_dim], + # stddev=1.0 / math.sqrt(128) + # ) + # ) + # self.weight_nce_bias = tf.Variable( + # name="loss_nce_bias", + # shape=(feature_dim), + # initial_value=tf.zeros((feature_dim)) + # ) def get_config(self): config = super(LayerCheeseMultipleOut, self).get_config()