dlr: tf graph changes

This commit is contained in:
Starbeamrainbowlabs 2023-03-03 22:44:49 +00:00
parent 750f46dbd2
commit 0734201107
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 42 additions and 35 deletions

View file

@ -19,9 +19,9 @@ import tensorflow as tf
from lib.dataset.dataset_mono import dataset_mono
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import sensitivity
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}")
@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None:
metrics=[
"accuracy",
dice_coefficient,
mean_iou,
sensitivity, # How many true positives were accurately predicted
mean_iou(),
sensitivity(), # How many true positives were accurately predicted
specificity # How many true negatives were accurately predicted?
# TODO: Add IoU, F1, Precision, Recall, here.
],

View file

@ -3,7 +3,9 @@ import math
import tensorflow as tf
def one_hot_mean_iou(y_true, y_pred, classes=2):
def make_one_hot_mean_iou():
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
def one_hot_mean_iou(y_true, y_pred, classes=2):
"""Compute the mean IoU for one-hot tensors.
Args:
y_true (tf.Tensor): The ground truth label.
@ -20,7 +22,7 @@ def one_hot_mean_iou(y_true, y_pred, classes=2):
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
iou.reset_state()
iou.update_state(y_true, y_pred)
return iou.result()
return one_hot_mean_iou

View file

@ -2,7 +2,10 @@ import math
import tensorflow as tf
def sensitivity(y_true, y_pred):
def make_sensitivity():
recall = tf.keras.metrics.Recall()
def sensitivity(y_true, y_pred):
print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
y_pred = tf.math.argmax(y_pred, axis=-1)
@ -11,6 +14,8 @@ def sensitivity(y_true, y_pred):
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
recall = tf.keras.metrics.Recall()
recall.reset_state()
recall.update_state(y_true, y_pred)
return recall.result()
return _sensitivity