prepare for NCE loss

.....but Tensorflow's implementation looks to be for supervised models :-(
This commit is contained in:
Starbeamrainbowlabs 2022-10-25 21:15:05 +01:00
parent bb0679a509
commit 98417a3e06
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 21 additions and 2 deletions

View file

@ -1,9 +1,10 @@
import math
import tensorflow as tf import tensorflow as tf
class LayerCheeseMultipleOut(tf.keras.layers.Layer): class LayerCheeseMultipleOut(tf.keras.layers.Layer):
def __init__(self, **kwargs): def __init__(self, batch_size, feature_dim, **kwargs):
"""Creates a new cheese multiple out layer. """Creates a new cheese multiple out layer.
This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs. This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs.
Basically, it just concatenates all inputs. Basically, it just concatenates all inputs.
@ -12,10 +13,28 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
""" """
super(LayerCheeseMultipleOut, self).__init__(**kwargs) super(LayerCheeseMultipleOut, self).__init__(**kwargs)
self.param_batch_size = batch_size
self.param_feature_dim = feature_dim
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07])) self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07]))
self.weight_nce = tf.Variable(
name="loss_nce",
shape=(batch_size, feature_dim),
initial_value=tf.random.truncated_normal(
(feature_dim),
stddev=1.0 / math.sqrt(128)
)
)
self.weight_nce_bias = tf.Variable(
name="loss_nce_bias",
shape=(feature_dim),
initial_value=tf.zeros((feature_dim))
)
def get_config(self): def get_config(self):
config = super(LayerCheeseMultipleOut, self).get_config() config = super(LayerCheeseMultipleOut, self).get_config()
config["batch_size"] = self.param_batch_size
config["feature_dim"] = self.param_feature_dim
return config return config
def call(self, inputs): def call(self, inputs):

View file

@ -49,7 +49,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
)(input_water) )(input_water)
layer_final = LayerCheeseMultipleOut() layer_final = LayerCheeseMultipleOut(batch_size=batch_size, feature_dim=feature_dim)
final = layer_final([ rainfall, water ]) final = layer_final([ rainfall, water ])
weight_temperature = layer_final.weight_temperature weight_temperature = layer_final.weight_temperature