LossDice: explicitly cast inputs to float32

This commit is contained in:
Starbeamrainbowlabs 2022-12-12 17:20:32 +00:00
parent dbf8f5617c
commit 449bc425a7
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -44,7 +44,9 @@ class LossDice(tf.keras.losses.Loss):
self.param_smooth = smooth
def call(self, y_true, y_pred):
return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth)
ground_truth = tf.cast(y_true, dtype=tf.float32)
prediction = tf.cast(y_pred, dtype=tf.float32)
return dice_coef_loss(ground_truth, prediction, smooth=self.param_smooth)
def get_config(self):
config = super(LossDice, self).get_config()