meaniou: implement one-hot version

it expects sparse, but our output is one-hot.
This commit is contained in:
Starbeamrainbowlabs 2023-03-03 22:04:21 +00:00
parent 6ffda40d48
commit 5c6789bf40
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 28 additions and 3 deletions

View file

@ -18,9 +18,10 @@ import tensorflow as tf
from lib.dataset.dataset_mono import dataset_mono from lib.dataset.dataset_mono import dataset_mono
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import sensitivity from lib.ai.components.MetricSensitivity import sensitivity
from lib.ai.components.MetricSpecificity import specificity from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
time_start = datetime.now() time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}") logger.info(f"Starting at {str(datetime.now().isoformat())}")
@ -165,6 +166,7 @@ else:
"MetricDice": MetricDice, "MetricDice": MetricDice,
"MetricSensitivity": MetricSensitivity, "MetricSensitivity": MetricSensitivity,
"MetricSpecificity": MetricSpecificity "MetricSpecificity": MetricSpecificity
"MetricMeanIoU": MetricMeanIoU
}) })
@ -189,8 +191,8 @@ if PATH_CHECKPOINT is None:
loss=loss_fn, loss=loss_fn,
metrics=[ metrics=[
"accuracy", "accuracy",
metric_dice_coefficient, dice_coefficient,
tf.keras.metrics.MeanIoU(num_classes=2), mean_iou,
sensitivity, # How many true positives were accurately predicted sensitivity, # How many true positives were accurately predicted
specificity # How many true negatives were accurately predicted? specificity # How many true negatives were accurately predicted?
# TODO: Add IoU, F1, Precision, Recall, here. # TODO: Add IoU, F1, Precision, Recall, here.

View file

@ -0,0 +1,23 @@
import math
import tensorflow as tf
def one_hot_mean_iou(y_true, y_pred, classes=2):
"""Compute the mean IoU for one-hot tensors.
Args:
y_true (tf.Tensor): The ground truth label.
y_pred (tf.Tensor): The output predicted by the model.
Returns:
tf.Tensor: The computed mean IoU.
"""
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
iou = tf.keras.metrics.MeanIoU(classes=classes)
iou.update_state(y_true, y_pred)
return iou.result()