mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-26 02:43:02 +00:00
LossDice: explicitly cast inputs to float32
This commit is contained in:
parent
dbf8f5617c
commit
449bc425a7
1 changed files with 3 additions and 1 deletions
|
@ -44,7 +44,9 @@ class LossDice(tf.keras.losses.Loss):
|
||||||
self.param_smooth = smooth
|
self.param_smooth = smooth
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
def call(self, y_true, y_pred):
|
||||||
return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth)
|
ground_truth = tf.cast(y_true, dtype=tf.float32)
|
||||||
|
prediction = tf.cast(y_pred, dtype=tf.float32)
|
||||||
|
return dice_coef_loss(ground_truth, prediction, smooth=self.param_smooth)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = super(LossDice, self).get_config()
|
config = super(LossDice, self).get_config()
|
||||||
|
|
Loading…
Reference in a new issue