diff --git a/aimodel/src/lib/ai/components/MetricMeanIoU.py b/aimodel/src/lib/ai/components/MetricMeanIoU.py index 470220f..6469472 100644 --- a/aimodel/src/lib/ai/components/MetricMeanIoU.py +++ b/aimodel/src/lib/ai/components/MetricMeanIoU.py @@ -3,9 +3,9 @@ import math import tensorflow as tf -def make_one_hot_mean_iou(): +def make_one_hot_mean_iou(classes=2): iou = tf.keras.metrics.MeanIoU(num_classes=classes) - def one_hot_mean_iou(y_true, y_pred, classes=2): + def one_hot_mean_iou(y_true, y_pred, ): """Compute the mean IoU for one-hot tensors. Args: y_true (tf.Tensor): The ground truth label.