From 5c6789bf40376eb7de1dce9b8229f31d189174c0 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 3 Mar 2023 22:04:21 +0000 Subject: [PATCH] meaniou: implement one-hot version it expects sparse, but our output is one-hot. --- aimodel/src/deeplabv3_plus_test_rainfall.py | 8 ++++--- .../src/lib/ai/components/MetricMeanIoU.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 aimodel/src/lib/ai/components/MetricMeanIoU.py diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index fe9c7ef..6fad6b3 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -18,9 +18,10 @@ import tensorflow as tf from lib.dataset.dataset_mono import dataset_mono from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice -from lib.ai.components.MetricDice import metric_dice_coefficient +from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient from lib.ai.components.MetricSensitivity import sensitivity from lib.ai.components.MetricSpecificity import specificity +from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou time_start = datetime.now() logger.info(f"Starting at {str(datetime.now().isoformat())}") @@ -165,6 +166,7 @@ else: "MetricDice": MetricDice, "MetricSensitivity": MetricSensitivity, "MetricSpecificity": MetricSpecificity + "MetricMeanIoU": MetricMeanIoU }) @@ -189,8 +191,8 @@ if PATH_CHECKPOINT is None: loss=loss_fn, metrics=[ "accuracy", - metric_dice_coefficient, - tf.keras.metrics.MeanIoU(num_classes=2), + dice_coefficient, + mean_iou, sensitivity, # How many true positives were accurately predicted specificity # How many true negatives were accurately predicted? # TODO: Add IoU, F1, Precision, Recall, here. diff --git a/aimodel/src/lib/ai/components/MetricMeanIoU.py b/aimodel/src/lib/ai/components/MetricMeanIoU.py new file mode 100644 index 0000000..522a1ae --- /dev/null +++ b/aimodel/src/lib/ai/components/MetricMeanIoU.py @@ -0,0 +1,23 @@ +import math + +import tensorflow as tf + + +def one_hot_mean_iou(y_true, y_pred, classes=2): + """Compute the mean IoU for one-hot tensors. + Args: + y_true (tf.Tensor): The ground truth label. + y_pred (tf.Tensor): The output predicted by the model. + + Returns: + tf.Tensor: The computed mean IoU. + """ + + y_pred = tf.math.argmax(y_pred, axis=-1) + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + + + iou = tf.keras.metrics.MeanIoU(classes=classes) + iou.update_state(y_true, y_pred) + return iou.result()