mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
dlr: tf graph changes
This commit is contained in:
parent
750f46dbd2
commit
0734201107
3 changed files with 42 additions and 35 deletions
|
@ -19,9 +19,9 @@ import tensorflow as tf
|
|||
from lib.dataset.dataset_mono import dataset_mono
|
||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||
from lib.ai.components.MetricSensitivity import sensitivity
|
||||
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
||||
from lib.ai.components.MetricSpecificity import specificity
|
||||
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
|
||||
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||
|
||||
time_start = datetime.now()
|
||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||
|
@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None:
|
|||
metrics=[
|
||||
"accuracy",
|
||||
dice_coefficient,
|
||||
mean_iou,
|
||||
sensitivity, # How many true positives were accurately predicted
|
||||
mean_iou(),
|
||||
sensitivity(), # How many true positives were accurately predicted
|
||||
specificity # How many true negatives were accurately predicted?
|
||||
# TODO: Add IoU, F1, Precision, Recall, here.
|
||||
],
|
||||
|
|
|
@ -3,7 +3,9 @@ import math
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||
def make_one_hot_mean_iou():
|
||||
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
|
||||
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||
"""Compute the mean IoU for one-hot tensors.
|
||||
Args:
|
||||
y_true (tf.Tensor): The ground truth label.
|
||||
|
@ -20,7 +22,7 @@ def one_hot_mean_iou(y_true, y_pred, classes=2):
|
|||
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
|
||||
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
|
||||
|
||||
|
||||
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
|
||||
iou.reset_state()
|
||||
iou.update_state(y_true, y_pred)
|
||||
return iou.result()
|
||||
return one_hot_mean_iou
|
|
@ -2,7 +2,10 @@ import math
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
def sensitivity(y_true, y_pred):
|
||||
|
||||
def make_sensitivity():
|
||||
recall = tf.keras.metrics.Recall()
|
||||
def sensitivity(y_true, y_pred):
|
||||
print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
|
||||
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
|
||||
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||
|
@ -11,6 +14,8 @@ def sensitivity(y_true, y_pred):
|
|||
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
|
||||
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
|
||||
|
||||
recall = tf.keras.metrics.Recall()
|
||||
recall.reset_state()
|
||||
recall.update_state(y_true, y_pred)
|
||||
return recall.result()
|
||||
|
||||
return _sensitivity
|
Loading…
Reference in a new issue