mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
prepare for NCE loss
.....but Tensorflow's implementation looks to be for supervised models :-(
This commit is contained in:
parent
bb0679a509
commit
98417a3e06
2 changed files with 21 additions and 2 deletions
|
@ -1,9 +1,10 @@
|
|||
import math
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, batch_size, feature_dim, **kwargs):
|
||||
"""Creates a new cheese multiple out layer.
|
||||
This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs.
|
||||
Basically, it just concatenates all inputs.
|
||||
|
@ -12,10 +13,28 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
|||
"""
|
||||
super(LayerCheeseMultipleOut, self).__init__(**kwargs)
|
||||
|
||||
self.param_batch_size = batch_size
|
||||
self.param_feature_dim = feature_dim
|
||||
|
||||
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07]))
|
||||
self.weight_nce = tf.Variable(
|
||||
name="loss_nce",
|
||||
shape=(batch_size, feature_dim),
|
||||
initial_value=tf.random.truncated_normal(
|
||||
(feature_dim),
|
||||
stddev=1.0 / math.sqrt(128)
|
||||
)
|
||||
)
|
||||
self.weight_nce_bias = tf.Variable(
|
||||
name="loss_nce_bias",
|
||||
shape=(feature_dim),
|
||||
initial_value=tf.zeros((feature_dim))
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = super(LayerCheeseMultipleOut, self).get_config()
|
||||
config["batch_size"] = self.param_batch_size
|
||||
config["feature_dim"] = self.param_feature_dim
|
||||
return config
|
||||
|
||||
def call(self, inputs):
|
||||
|
|
|
@ -49,7 +49,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
|
|||
)(input_water)
|
||||
|
||||
|
||||
layer_final = LayerCheeseMultipleOut()
|
||||
layer_final = LayerCheeseMultipleOut(batch_size=batch_size, feature_dim=feature_dim)
|
||||
final = layer_final([ rainfall, water ])
|
||||
weight_temperature = layer_final.weight_temperature
|
||||
|
||||
|
|
Loading…
Reference in a new issue