From 0129c35a357e2e5d50c234fbfd0be4bdf75453a2 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 9 Dec 2022 19:06:26 +0000 Subject: [PATCH] LossDice: remove weird K.* functions --- aimodel/src/lib/ai/components/LossDice.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/aimodel/src/lib/ai/components/LossDice.py b/aimodel/src/lib/ai/components/LossDice.py index f2aacfb..c2f6cc8 100644 --- a/aimodel/src/lib/ai/components/LossDice.py +++ b/aimodel/src/lib/ai/components/LossDice.py @@ -12,10 +12,11 @@ def dice_coef(y_true, y_pred, smooth=100): Returns: Tensor: The dice coefficient. """ - y_true_f = tf.flatten(y_true) - y_pred_f = K.flatten(y_pred) - intersection = K.sum(y_true_f * y_pred_f) - dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + # K.flatten → tf.reshape; K.sum → tf.math.reduce_sum + y_true_f = tf.reshape(y_true, [-1]) + y_pred_f = tf.reshape(y_pred, [-1]) + intersection = tf.math.reduce_sum(y_true_f * y_pred_f) + dice = (2. * intersection + smooth) / (tf.math.reduce_sum(y_true_f) + tf.math.reduce_sum(y_pred_f) + smooth) return dice def dice_coef_loss(y_true, y_pred, **kwargs):