mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
prepare for NCE loss
.....but Tensorflow's implementation looks to be for supervised models :-(
This commit is contained in:
parent
bb0679a509
commit
98417a3e06
2 changed files with 21 additions and 2 deletions
|
@ -1,9 +1,10 @@
|
||||||
|
import math
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, batch_size, feature_dim, **kwargs):
|
||||||
"""Creates a new cheese multiple out layer.
|
"""Creates a new cheese multiple out layer.
|
||||||
This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs.
|
This layer is useful if you have multiple outputs and a custom loss function that requires multiple inputs.
|
||||||
Basically, it just concatenates all inputs.
|
Basically, it just concatenates all inputs.
|
||||||
|
@ -12,10 +13,28 @@ class LayerCheeseMultipleOut(tf.keras.layers.Layer):
|
||||||
"""
|
"""
|
||||||
super(LayerCheeseMultipleOut, self).__init__(**kwargs)
|
super(LayerCheeseMultipleOut, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
self.param_batch_size = batch_size
|
||||||
|
self.param_feature_dim = feature_dim
|
||||||
|
|
||||||
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07]))
|
self.weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([0.07]))
|
||||||
|
self.weight_nce = tf.Variable(
|
||||||
|
name="loss_nce",
|
||||||
|
shape=(batch_size, feature_dim),
|
||||||
|
initial_value=tf.random.truncated_normal(
|
||||||
|
(feature_dim),
|
||||||
|
stddev=1.0 / math.sqrt(128)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.weight_nce_bias = tf.Variable(
|
||||||
|
name="loss_nce_bias",
|
||||||
|
shape=(feature_dim),
|
||||||
|
initial_value=tf.zeros((feature_dim))
|
||||||
|
)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = super(LayerCheeseMultipleOut, self).get_config()
|
config = super(LayerCheeseMultipleOut, self).get_config()
|
||||||
|
config["batch_size"] = self.param_batch_size
|
||||||
|
config["feature_dim"] = self.param_feature_dim
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
|
|
|
@ -49,7 +49,7 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
|
||||||
)(input_water)
|
)(input_water)
|
||||||
|
|
||||||
|
|
||||||
layer_final = LayerCheeseMultipleOut()
|
layer_final = LayerCheeseMultipleOut(batch_size=batch_size, feature_dim=feature_dim)
|
||||||
final = layer_final([ rainfall, water ])
|
final = layer_final([ rainfall, water ])
|
||||||
weight_temperature = layer_final.weight_temperature
|
weight_temperature = layer_final.weight_temperature
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue