LossDice: remove weird K.* functions

This commit is contained in:
Starbeamrainbowlabs 2022-12-09 19:06:26 +00:00
parent 659fc97fd4
commit 0129c35a35
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -12,10 +12,11 @@ def dice_coef(y_true, y_pred, smooth=100):
Returns:
Tensor: The dice coefficient.
"""
y_true_f = tf.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
# K.flatten → tf.reshape; K.sum → tf.math.reduce_sum
y_true_f = tf.reshape(y_true, [-1])
y_pred_f = tf.reshape(y_pred, [-1])
intersection = tf.math.reduce_sum(y_true_f * y_pred_f)
dice = (2. * intersection + smooth) / (tf.math.reduce_sum(y_true_f) + tf.math.reduce_sum(y_pred_f) + smooth)
return dice
def dice_coef_loss(y_true, y_pred, **kwargs):