diff --git a/aimodel/src/lib/ai/components/LossDice.py b/aimodel/src/lib/ai/components/LossDice.py index 5abecd5..2f5c0c5 100644 --- a/aimodel/src/lib/ai/components/LossDice.py +++ b/aimodel/src/lib/ai/components/LossDice.py @@ -44,7 +44,9 @@ class LossDice(tf.keras.losses.Loss): self.param_smooth = smooth def call(self, y_true, y_pred): - return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth) + ground_truth = tf.cast(y_true, dtype=tf.float32) + prediction = tf.cast(y_pred, dtype=tf.float32) + return dice_coef_loss(ground_truth, prediction, smooth=self.param_smooth) def get_config(self): config = super(LossDice, self).get_config()