dlr: tf graph changes

This commit is contained in:
Starbeamrainbowlabs 2023-03-03 22:44:49 +00:00
parent 750f46dbd2
commit 0734201107
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 42 additions and 35 deletions

View file

@ -19,9 +19,9 @@ import tensorflow as tf
from lib.dataset.dataset_mono import dataset_mono from lib.dataset.dataset_mono import dataset_mono
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import sensitivity from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
time_start = datetime.now() time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}") logger.info(f"Starting at {str(datetime.now().isoformat())}")
@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None:
metrics=[ metrics=[
"accuracy", "accuracy",
dice_coefficient, dice_coefficient,
mean_iou, mean_iou(),
sensitivity, # How many true positives were accurately predicted sensitivity(), # How many true positives were accurately predicted
specificity # How many true negatives were accurately predicted? specificity # How many true negatives were accurately predicted?
# TODO: Add IoU, F1, Precision, Recall, here. # TODO: Add IoU, F1, Precision, Recall, here.
], ],

View file

@ -3,24 +3,26 @@ import math
import tensorflow as tf import tensorflow as tf
def one_hot_mean_iou(y_true, y_pred, classes=2): def make_one_hot_mean_iou():
"""Compute the mean IoU for one-hot tensors.
Args:
y_true (tf.Tensor): The ground truth label.
y_pred (tf.Tensor): The output predicted by the model.
Returns:
tf.Tensor: The computed mean IoU.
"""
print("DEBUG:meaniou y_pred.shape BEFORE", y_pred.shape)
print("DEBUG:meaniou y_true.shape BEFORE", y_true.shape)
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
iou = tf.keras.metrics.MeanIoU(num_classes=classes) iou = tf.keras.metrics.MeanIoU(num_classes=classes)
iou.update_state(y_true, y_pred) def one_hot_mean_iou(y_true, y_pred, classes=2):
return iou.result() """Compute the mean IoU for one-hot tensors.
Args:
y_true (tf.Tensor): The ground truth label.
y_pred (tf.Tensor): The output predicted by the model.
Returns:
tf.Tensor: The computed mean IoU.
"""
print("DEBUG:meaniou y_pred.shape BEFORE", y_pred.shape)
print("DEBUG:meaniou y_true.shape BEFORE", y_true.shape)
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape)
print("DEBUG:meaniou y_true.shape AFTER", y_true.shape)
iou.reset_state()
iou.update_state(y_true, y_pred)
return iou.result()
return one_hot_mean_iou

View file

@ -2,15 +2,20 @@ import math
import tensorflow as tf import tensorflow as tf
def sensitivity(y_true, y_pred):
print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape) def make_sensitivity():
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
recall = tf.keras.metrics.Recall() recall = tf.keras.metrics.Recall()
recall.update_state(y_true, y_pred) def sensitivity(y_true, y_pred):
return recall.result() print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape)
print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape)
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape)
print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape)
recall.reset_state()
recall.update_state(y_true, y_pred)
return recall.result()
return _sensitivity