mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-16 14:43:01 +00:00
dlr: tf graph changes
This commit is contained in:
parent
750f46dbd2
commit
0734201107
3 changed files with 42 additions and 35 deletions
|
@ -19,9 +19,9 @@ import tensorflow as tf
|
||||||
from lib.dataset.dataset_mono import dataset_mono
|
from lib.dataset.dataset_mono import dataset_mono
|
||||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||||
from lib.ai.components.MetricSensitivity import sensitivity
|
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
||||||
from lib.ai.components.MetricSpecificity import specificity
|
from lib.ai.components.MetricSpecificity import specificity
|
||||||
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
|
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||||
|
|
||||||
time_start = datetime.now()
|
time_start = datetime.now()
|
||||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||||
|
@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None:
|
||||||
metrics=[
|
metrics=[
|
||||||
"accuracy",
|
"accuracy",
|
||||||
dice_coefficient,
|
dice_coefficient,
|
||||||
mean_iou,
|
mean_iou(),
|
||||||
sensitivity, # How many true positives were accurately predicted
|
sensitivity(), # How many true positives were accurately predicted
|
||||||
specificity # How many true negatives were accurately predicted?
|
specificity # How many true negatives were accurately predicted?
|
||||||
# TODO: Add IoU, F1, Precision, Recall, here.
|
# TODO: Add IoU, F1, Precision, Recall, here.
|
||||||
],
|
],
|
||||||
|
|
|
@ -3,7 +3,9 @@ import math
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
def make_one_hot_mean_iou():
|
||||||
|
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
|
||||||
|
def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||||
"""Compute the mean IoU for one-hot tensors.
|
"""Compute the mean IoU for one-hot tensors.
|
||||||
Args:
|
Args:
|
||||||
y_true (tf.Tensor): The ground truth label.
|
y_true (tf.Tensor): The ground truth label.
|
||||||
|
@ -20,7 +22,7 @@ def one_hot_mean_iou(y_true, y_pred, classes=2):
|
||||||
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
|
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
|
||||||
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
|
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
|
||||||
|
|
||||||
|
iou.reset_state()
|
||||||
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
|
|
||||||
iou.update_state(y_true, y_pred)
|
iou.update_state(y_true, y_pred)
|
||||||
return iou.result()
|
return iou.result()
|
||||||
|
return one_hot_mean_iou
|
|
@ -2,7 +2,10 @@ import math
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def sensitivity(y_true, y_pred):
|
|
||||||
|
def make_sensitivity():
|
||||||
|
recall = tf.keras.metrics.Recall()
|
||||||
|
def sensitivity(y_true, y_pred):
|
||||||
print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
|
print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
|
||||||
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
|
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
|
||||||
y_pred = tf.math.argmax(y_pred, axis=-1)
|
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||||
|
@ -11,6 +14,8 @@ def sensitivity(y_true, y_pred):
|
||||||
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
|
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
|
||||||
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
|
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
|
||||||
|
|
||||||
recall = tf.keras.metrics.Recall()
|
recall.reset_state()
|
||||||
recall.update_state(y_true, y_pred)
|
recall.update_state(y_true, y_pred)
|
||||||
return recall.result()
|
return recall.result()
|
||||||
|
|
||||||
|
return _sensitivity
|
Loading…
Reference in a new issue