mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-16 14:43:01 +00:00
again
This commit is contained in:
parent
e04d6ab1b6
commit
2f0ce0aa13
1 changed files with 1 additions and 1 deletions
|
@ -31,7 +31,7 @@ class LossCrossEntropyDice(tf.keras.losses.Loss):
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
def call(self, y_true, y_pred):
|
||||||
y_true = tf.cast(y_true, tf.float32)
|
y_true = tf.cast(y_true, tf.float32)
|
||||||
y_true = tf.cast(tf.one_hot(y_true, 2), dtype=tf.int32) # Input is sparse
|
y_true = tf.one_hot(tf.cast(y_true, dtype=tf.int32), 2) # Input is sparse
|
||||||
o = tf.nn.sigmoid_cross_entropy_with_logits(y_true, y_pred) + dice_loss(y_true, y_pred)
|
o = tf.nn.sigmoid_cross_entropy_with_logits(y_true, y_pred) + dice_loss(y_true, y_pred)
|
||||||
return tf.reduce_mean(o)
|
return tf.reduce_mean(o)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue