diff --git a/aimodel/src/lib/ai/components/MetricDice.py b/aimodel/src/lib/ai/components/MetricDice.py index 9ac35ec..96e19f4 100644 --- a/aimodel/src/lib/ai/components/MetricDice.py +++ b/aimodel/src/lib/ai/components/MetricDice.py @@ -26,5 +26,5 @@ def dice_coefficient(y_true, y_pred): def metric_dice_coefficient(y_true, y_pred): - y_pred = tf.math.argmax(y_pred) + y_pred = tf.math.argmax(y_pred, axis=-1) return dice_coefficient(y_true, y_pred) \ No newline at end of file