diff --git a/aimodel/src/lib/ai/components/MetricDice.py b/aimodel/src/lib/ai/components/MetricDice.py index f2d7abb..956f9f5 100644 --- a/aimodel/src/lib/ai/components/MetricDice.py +++ b/aimodel/src/lib/ai/components/MetricDice.py @@ -23,8 +23,8 @@ def dice_coefficient(y_true, y_pred): def metric_dice_coefficient(y_true, y_pred): + y_pred = tf.math.argmax(y_pred, axis=-1) y_true = tf.cast(y_true, dtype=tf.float32) y_pred = tf.cast(y_pred, dtype=tf.float32) - y_pred = tf.math.argmax(y_pred, axis=-1) return dice_coefficient(y_true, y_pred) \ No newline at end of file