mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
LossDice: explicitly cast inputs to float32
This commit is contained in:
parent
dbf8f5617c
commit
449bc425a7
1 changed files with 3 additions and 1 deletions
|
@ -44,7 +44,9 @@ class LossDice(tf.keras.losses.Loss):
|
|||
self.param_smooth = smooth
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth)
|
||||
ground_truth = tf.cast(y_true, dtype=tf.float32)
|
||||
prediction = tf.cast(y_pred, dtype=tf.float32)
|
||||
return dice_coef_loss(ground_truth, prediction, smooth=self.param_smooth)
|
||||
|
||||
def get_config(self):
|
||||
config = super(LossDice, self).get_config()
|
||||
|
|
Loading…
Reference in a new issue