From 98417a3e069f09bbd3d790bbb0b99f71eb9f4be2 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Tue, 25 Oct 2022 21:15:05 +0100 Subject: [PATCH] prepare for NCE loss .....but Tensorflow's implementation looks to be for supervised models :-( --- .../ai/components/LayerCheeseMultipleOut.py | 21 ++++++++++++++++++- .../lib/ai/model_rainfallwater_contrastive.py | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py index 33039e3..87c2e43 100644 --- a/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py +++ b/aimodel/src/lib/ai/components/LayerCheeseMultipleOut.py @@ -1,9 +1,10 @@ +import math import tensorflow as tf class LayerCheeseMultipleOut(tf.keras.layers.Layer): - def __init__(self, **kwargs): + def __init__(self, batch_size, feature_dim, **kwargs): """Creates a new cheese multiple out layer. This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs. Basically, it just concatenates all inputs. @@ -12,10 +13,28 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer): """ super(LayerCheeseMultipleOut, self).__init__(**kwargs) + self.param_batch_size = batch_size + self.param_feature_dim = feature_dim + self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07])) + self.weight_nce = tf.Variable( + name="loss_nce", + shape=(batch_size, feature_dim), + initial_value=tf.random.truncated_normal( + (feature_dim), + stddev=1.0 / math.sqrt(128) + ) + ) + self.weight_nce_bias = tf.Variable( + name="loss_nce_bias", + shape=(feature_dim), + initial_value=tf.zeros((feature_dim)) + ) def get_config(self): config = super(LayerCheeseMultipleOut, self).get_config() + config["batch_size"] = self.param_batch_size + config["feature_dim"] = self.param_feature_dim return config def call(self, inputs): diff --git a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py index cf970ed..53911e0 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_contrastive.py +++ b/aimodel/src/lib/ai/model_rainfallwater_contrastive.py @@ -49,7 +49,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur )(input_water) - layer_final = LayerCheeseMultipleOut() + layer_final = LayerCheeseMultipleOut(batch_size=batch_size, feature_dim=feature_dim) final = layer_final([ rainfall, water ]) weight_temperature = layer_final.weight_temperature