From 37d1598b0b70df03472284039d23f4e53fac7606 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 13 Jan 2023 18:21:11 +0000 Subject: [PATCH] loss cel+dice: fixup --- aimodel/src/lib/ai/components/LossCrossEntropyDice.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aimodel/src/lib/ai/components/LossCrossEntropyDice.py b/aimodel/src/lib/ai/components/LossCrossEntropyDice.py index ae07525..ce20cd7 100644 --- a/aimodel/src/lib/ai/components/LossCrossEntropyDice.py +++ b/aimodel/src/lib/ai/components/LossCrossEntropyDice.py @@ -22,17 +22,19 @@ def dice_loss(y_true, y_pred): class LossCrossEntropyDice(tf.keras.losses.Loss): """Cross-entropy loss and dice loss combined together into one nice neat package. Combines the two with mean. + The ground truth labels should be sparse, NOT one-hot. The predictions should be one-hot, NOT sparse. @source https://lars76.github.io/2018/09/27/loss-functions-for-segmentation.html#9 """ - + def __init__(self, **kwargs): super(LossCrossEntropyDice, self).__init__(**kwargs) - + def call(self, y_true, y_pred): y_true = tf.cast(y_true, tf.float32) + y_true = tf.one_hot(y_true, 2) # Input is sparse o = tf.nn.sigmoid_cross_entropy_with_logits(y_true, y_pred) + dice_loss(y_true, y_pred) return tf.reduce_mean(o) - + def get_config(self): config = super(LossCrossEntropyDice, self).get_config() config.update({