mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 17:23:01 +00:00
LossDice: remove weird K.* functions
This commit is contained in:
parent
659fc97fd4
commit
0129c35a35
1 changed files with 5 additions and 4 deletions
|
@ -12,10 +12,11 @@ def dice_coef(y_true, y_pred, smooth=100):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The dice coefficient.
|
Tensor: The dice coefficient.
|
||||||
"""
|
"""
|
||||||
y_true_f = tf.flatten(y_true)
|
# K.flatten → tf.reshape; K.sum → tf.math.reduce_sum
|
||||||
y_pred_f = K.flatten(y_pred)
|
y_true_f = tf.reshape(y_true, [-1])
|
||||||
intersection = K.sum(y_true_f * y_pred_f)
|
y_pred_f = tf.reshape(y_pred, [-1])
|
||||||
dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
|
intersection = tf.math.reduce_sum(y_true_f * y_pred_f)
|
||||||
|
dice = (2. * intersection + smooth) / (tf.math.reduce_sum(y_true_f) + tf.math.reduce_sum(y_pred_f) + smooth)
|
||||||
return dice
|
return dice
|
||||||
|
|
||||||
def dice_coef_loss(y_true, y_pred, **kwargs):
|
def dice_coef_loss(y_true, y_pred, **kwargs):
|
||||||
|
|
Loading…
Reference in a new issue