dlr: tf graph changes

Starbeamrainbowlabs 2023-03-03 22:44:49 +00:00
parent 750f46dbd2
commit 0734201107
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 42 additions and 35 deletions

View file

@@ -19,9 +19,9 @@ import tensorflow as tf
 from lib.dataset.dataset_mono import dataset_mono
 from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
 from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
-from lib.ai.components.MetricSensitivity import sensitivity
+from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
 from lib.ai.components.MetricSpecificity import specificity
-from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou
+from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
 
 time_start = datetime.now()
 logger.info(f"Starting at {str(datetime.now().isoformat())}")
@@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None:
 		metrics=[
 			"accuracy",
 			dice_coefficient,
-			mean_iou,
-			sensitivity, # How many true positives were accurately predicted
+			mean_iou(),
+			sensitivity(), # How many true positives were accurately predicted
 			specificity # How many true negatives were accurately predicted?
 			# TODO: Add IoU, F1, Precision, Recall, here.
 		],
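The compile-site change above swaps the bare metric functions for calls to the new factories: each factory builds its stateful tf.keras metric object once and returns a closure, and it is that closure which ends up in the `metrics=[...]` list. Below is a minimal, self-contained sketch of the same pattern; the model, loss, and `make_recall_metric` name are placeholders for illustration, not code from this repository.

```python
import tensorflow as tf

def make_recall_metric():
	recall = tf.keras.metrics.Recall()  # stateful metric object built once, outside the train-step graph
	def recall_metric(y_true, y_pred):
		recall.reset_state()            # report a per-batch value rather than a running mean
		recall.update_state(y_true, y_pred)
		return recall.result()
	return recall_metric

# Placeholder model and loss, purely to show where the factory call sits.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
model.compile(
	optimizer="adam",
	loss="binary_crossentropy",
	metrics=["accuracy", make_recall_metric()],  # note the call: the returned closure is passed, not the factory
)
```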

View file

@@ -3,7 +3,9 @@ import math
 import tensorflow as tf
 
 
-def one_hot_mean_iou(y_true, y_pred, classes=2):
+def make_one_hot_mean_iou():
+	iou = tf.keras.metrics.MeanIoU(num_classes=classes)
+	def one_hot_mean_iou(y_true, y_pred, classes=2):
 	"""Compute the mean IoU for one-hot tensors.
 	Args:
 		y_true (tf.Tensor): The ground truth label.
@@ -20,7 +22,7 @@ def one_hot_mean_iou(y_true, y_pred, classes=2):
 	print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
 	print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
 	
-	iou = tf.keras.metrics.MeanIoU(num_classes=classes)
+	iou.reset_state()
 	iou.update_state(y_true, y_pred)
 	return iou.result()
+	
+	return one_hot_mean_iou
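Read as a whole, the new file constructs the `tf.keras.metrics.MeanIoU` object once in the factory and only resets and updates it per call. Below is a self-contained sketch of that factory form. Two assumptions on my part: the class count is supplied to the factory (the hunk above still references `classes` inside `make_one_hot_mean_iou()`, where it is not defined), and both tensors arrive one-hot encoded, as the docstring suggests.

```python
import tensorflow as tf

def make_one_hot_mean_iou(classes=2):  # assumption: class count passed to the factory
	# The stateful Keras metric is constructed once, when the factory is called...
	iou = tf.keras.metrics.MeanIoU(num_classes=classes)
	def one_hot_mean_iou(y_true, y_pred):
		# ...and only reset and updated on each per-batch metric call.
		y_true = tf.math.argmax(y_true, axis=-1)  # one-hot → class indices
		y_pred = tf.math.argmax(y_pred, axis=-1)
		iou.reset_state()
		iou.update_state(y_true, y_pred)
		return iou.result()
	return one_hot_mean_iou

# Example usage with dummy one-hot tensors:
metric = make_one_hot_mean_iou()
y_true = tf.one_hot([0, 1, 1, 0], depth=2)
y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]])
print(float(metric(y_true, y_pred)))
```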

View file

@@ -2,7 +2,10 @@ import math
 
 import tensorflow as tf
 
-def sensitivity(y_true, y_pred):
+def make_sensitivity():
+	recall = tf.keras.metrics.Recall()
+	def sensitivity(y_true, y_pred):
 	print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
 	print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
 	y_pred = tf.math.argmax(y_pred, axis=-1)
@@ -11,6 +14,8 @@ def sensitivity(y_true, y_pred):
 	print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
 	print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
 	
-	recall = tf.keras.metrics.Recall()
+	recall.reset_state()
 	recall.update_state(y_true, y_pred)
 	return recall.result()
+	
+	return _sensitivity
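The commit message ("tf graph changes") and the direction of these edits suggest the underlying problem: the old functions constructed a fresh `tf.keras.metrics` object on every metric call, which means creating variables inside the compiled training graph, and `tf.function` refuses to create variables on anything but its first trace. That is my reading rather than something stated in the commit; a small sketch of the failure mode, with placeholder data, follows.

```python
import tensorflow as tf

@tf.function
def old_style_recall(y_true, y_pred):
	recall = tf.keras.metrics.Recall()  # variables created inside the traced function
	recall.update_state(y_true, y_pred)
	return recall.result()

# The first call traces the function and is allowed to create the metric's variables.
old_style_recall(tf.constant([1.0, 0.0, 1.0]), tf.constant([0.9, 0.2, 0.8]))
try:
	# A different input shape forces a retrace; creating variables again on this
	# non-first trace raises a ValueError.
	old_style_recall(tf.constant([1.0, 0.0]), tf.constant([0.9, 0.2]))
except ValueError as err:
	print("old-style metric fails under tf.function:", err)
```

Moving the `Recall()` construction out into `make_sensitivity()` means the returned closure only resets and updates already-existing variables inside the graph, which traces cleanly across batches.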