mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
meaniou: implement one-hot version
it expects sparse, but our output is one-hot.
This commit is contained in:
parent
6ffda40d48
commit
5c6789bf40
2 changed files with 28 additions and 3 deletions
|
@ -18,9 +18,10 @@ import tensorflow as tf
|
||||||
|
|
||||||
from lib.dataset.dataset_mono import dataset_mono
|
from lib.dataset.dataset_mono import dataset_mono
|
||||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||||
from lib.ai.components.MetricDice import metric_dice_coefficient
|
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||||
from lib.ai.components.MetricSensitivity import sensitivity
|
from lib.ai.components.MetricSensitivity import sensitivity
|
||||||
from lib.ai.components.MetricSpecificity import specificity
|
from lib.ai.components.MetricSpecificity import specificity
|
||||||
|
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
|
||||||
|
|
||||||
time_start = datetime.now()
|
time_start = datetime.now()
|
||||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||||
|
@ -165,6 +166,7 @@ else:
|
||||||
"MetricDice": MetricDice,
|
"MetricDice": MetricDice,
|
||||||
"MetricSensitivity": MetricSensitivity,
|
"MetricSensitivity": MetricSensitivity,
|
||||||
"MetricSpecificity": MetricSpecificity
|
"MetricSpecificity": MetricSpecificity
|
||||||
|
"MetricMeanIoU": MetricMeanIoU
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,8 +191,8 @@ if PATH_CHECKPOINT is None:
|
||||||
loss=loss_fn,
|
loss=loss_fn,
|
||||||
metrics=[
|
metrics=[
|
||||||
"accuracy",
|
"accuracy",
|
||||||
metric_dice_coefficient,
|
dice_coefficient,
|
||||||
tf.keras.metrics.MeanIoU(num_classes=2),
|
mean_iou,
|
||||||
sensitivity, # How many true positives were accurately predicted
|
sensitivity, # How many true positives were accurately predicted
|
||||||
specificity # How many true negatives were accurately predicted?
|
specificity # How many true negatives were accurately predicted?
|
||||||
# TODO: Add IoU, F1, Precision, Recall, here.
|
# TODO: Add IoU, F1, Precision, Recall, here.
|
||||||
|
|
23
aimodel/src/lib/ai/components/MetricMeanIoU.py
Normal file
23
aimodel/src/lib/ai/components/MetricMeanIoU.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||||
|
"""Compute the mean IoU for one-hot tensors.
|
||||||
|
Args:
|
||||||
|
y_true (tf.Tensor): The ground truth label.
|
||||||
|
y_pred (tf.Tensor): The output predicted by the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.Tensor: The computed mean IoU.
|
||||||
|
"""
|
||||||
|
|
||||||
|
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||||
|
y_true = tf.cast(y_true, dtype=tf.float32)
|
||||||
|
y_pred = tf.cast(y_pred, dtype=tf.float32)
|
||||||
|
|
||||||
|
|
||||||
|
iou = tf.keras.metrics.MeanIoU(classes=classes)
|
||||||
|
iou.update_state(y_true, y_pred)
|
||||||
|
return iou.result()
|
Loading…
Reference in a new issue