mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
meaniou: implement one-hot version
it expects sparse, but our output is one-hot.
This commit is contained in:
parent
6ffda40d48
commit
5c6789bf40
2 changed files with 28 additions and 3 deletions
|
@ -18,9 +18,10 @@ import tensorflow as tf
|
|||
|
||||
from lib.dataset.dataset_mono import dataset_mono
|
||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||
from lib.ai.components.MetricDice import metric_dice_coefficient
|
||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||
from lib.ai.components.MetricSensitivity import sensitivity
|
||||
from lib.ai.components.MetricSpecificity import specificity
|
||||
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
|
||||
|
||||
time_start = datetime.now()
|
||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||
|
@ -165,6 +166,7 @@ else:
|
|||
"MetricDice": MetricDice,
|
||||
"MetricSensitivity": MetricSensitivity,
|
||||
"MetricSpecificity": MetricSpecificity
|
||||
"MetricMeanIoU": MetricMeanIoU
|
||||
})
|
||||
|
||||
|
||||
|
@ -189,8 +191,8 @@ if PATH_CHECKPOINT is None:
|
|||
loss=loss_fn,
|
||||
metrics=[
|
||||
"accuracy",
|
||||
metric_dice_coefficient,
|
||||
tf.keras.metrics.MeanIoU(num_classes=2),
|
||||
dice_coefficient,
|
||||
mean_iou,
|
||||
sensitivity, # How many true positives were accurately predicted
|
||||
specificity # How many true negatives were accurately predicted?
|
||||
# TODO: Add IoU, F1, Precision, Recall, here.
|
||||
|
|
23
aimodel/src/lib/ai/components/MetricMeanIoU.py
Normal file
23
aimodel/src/lib/ai/components/MetricMeanIoU.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||
"""Compute the mean IoU for one-hot tensors.
|
||||
Args:
|
||||
y_true (tf.Tensor): The ground truth label.
|
||||
y_pred (tf.Tensor): The output predicted by the model.
|
||||
|
||||
Returns:
|
||||
tf.Tensor: The computed mean IoU.
|
||||
"""
|
||||
|
||||
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||
y_true = tf.cast(y_true, dtype=tf.float32)
|
||||
y_pred = tf.cast(y_pred, dtype=tf.float32)
|
||||
|
||||
|
||||
iou = tf.keras.metrics.MeanIoU(classes=classes)
|
||||
iou.update_state(y_true, y_pred)
|
||||
return iou.result()
|
Loading…
Reference in a new issue