mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
LossDice: remove weird K.* functions
This commit is contained in:
parent
659fc97fd4
commit
0129c35a35
1 changed files with 5 additions and 4 deletions
|
@ -12,10 +12,11 @@ def dice_coef(y_true, y_pred, smooth=100):
|
|||
Returns:
|
||||
Tensor: The dice coefficient.
|
||||
"""
|
||||
y_true_f = tf.flatten(y_true)
|
||||
y_pred_f = K.flatten(y_pred)
|
||||
intersection = K.sum(y_true_f * y_pred_f)
|
||||
dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
|
||||
# K.flatten → tf.reshape; K.sum → tf.math.reduce_sum
|
||||
y_true_f = tf.reshape(y_true, [-1])
|
||||
y_pred_f = tf.reshape(y_pred, [-1])
|
||||
intersection = tf.math.reduce_sum(y_true_f * y_pred_f)
|
||||
dice = (2. * intersection + smooth) / (tf.math.reduce_sum(y_true_f) + tf.math.reduce_sum(y_pred_f) + smooth)
|
||||
return dice
|
||||
|
||||
def dice_coef_loss(y_true, y_pred, **kwargs):
|
||||
|
|
Loading…
Reference in a new issue