diff --git a/aimodel/src/lib/ai/components/LossDice.py b/aimodel/src/lib/ai/components/LossDice.py index f2aacfb..c2f6cc8 100644 --- a/aimodel/src/lib/ai/components/LossDice.py +++ b/aimodel/src/lib/ai/components/LossDice.py @@ -12,10 +12,11 @@ def dice_coef(y_true, y_pred, smooth=100): Returns: Tensor: The dice coefficient. """ - y_true_f = tf.flatten(y_true) - y_pred_f = K.flatten(y_pred) - intersection = K.sum(y_true_f * y_pred_f) - dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + # K.flatten → tf.reshape; K.sum → tf.math.reduce_sum + y_true_f = tf.reshape(y_true, [-1]) + y_pred_f = tf.reshape(y_pred, [-1]) + intersection = tf.math.reduce_sum(y_true_f * y_pred_f) + dice = (2. * intersection + smooth) / (tf.math.reduce_sum(y_true_f) + tf.math.reduce_sum(y_pred_f) + smooth) return dice def dice_coef_loss(y_true, y_pred, **kwargs):