From 0734201107053d18048c0360e14453454ff62afd Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 3 Mar 2023 22:44:49 +0000 Subject: [PATCH] dlr: tf graph changes --- aimodel/src/deeplabv3_plus_test_rainfall.py | 8 ++-- .../src/lib/ai/components/MetricMeanIoU.py | 42 ++++++++++--------- .../lib/ai/components/MetricSensitivity.py | 27 +++++++----- 3 files changed, 42 insertions(+), 35 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 9102ebd..ad0681c 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -19,9 +19,9 @@ import tensorflow as tf from lib.dataset.dataset_mono import dataset_mono from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient -from lib.ai.components.MetricSensitivity import sensitivity +from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity from lib.ai.components.MetricSpecificity import specificity -from lib.ai.components.MetricMeanIoU import one_hot_mean_iou as mean_iou +from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou time_start = datetime.now() logger.info(f"Starting at {str(datetime.now().isoformat())}") @@ -188,8 +188,8 @@ if PATH_CHECKPOINT is None: metrics=[ "accuracy", dice_coefficient, - mean_iou, - sensitivity, # How many true positives were accurately predicted + mean_iou(), + sensitivity(), # How many true positives were accurately predicted specificity # How many true negatives were accurately predicted? # TODO: Add IoU, F1, Precision, Recall, here. ], diff --git a/aimodel/src/lib/ai/components/MetricMeanIoU.py b/aimodel/src/lib/ai/components/MetricMeanIoU.py index 523c2f6..470220f 100644 --- a/aimodel/src/lib/ai/components/MetricMeanIoU.py +++ b/aimodel/src/lib/ai/components/MetricMeanIoU.py @@ -3,24 +3,26 @@ import math import tensorflow as tf -def one_hot_mean_iou(y_true, y_pred, classes=2): - """Compute the mean IoU for one-hot tensors. - Args: - y_true (tf.Tensor): The ground truth label. - y_pred (tf.Tensor): The output predicted by the model. - - Returns: - tf.Tensor: The computed mean IoU. - """ - print("DEBUG:meaniou y_pred.shape BEFORE", y_pred.shape) - print("DEBUG:meaniou y_true.shape BEFORE", y_true.shape) - y_pred = tf.math.argmax(y_pred, axis=-1) - y_true = tf.cast(y_true, dtype=tf.float32) - y_pred = tf.cast(y_pred, dtype=tf.float32) - print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape) - print("DEBUG:meaniou y_true.shape AFTER", y_true.shape) - - +def make_one_hot_mean_iou(): iou = tf.keras.metrics.MeanIoU(num_classes=classes) - iou.update_state(y_true, y_pred) - return iou.result() + def one_hot_mean_iou(y_true, y_pred, classes=2): + """Compute the mean IoU for one-hot tensors. + Args: + y_true (tf.Tensor): The ground truth label. + y_pred (tf.Tensor): The output predicted by the model. + + Returns: + tf.Tensor: The computed mean IoU. + """ + print("DEBUG:meaniou y_pred.shape BEFORE", y_pred.shape) + print("DEBUG:meaniou y_true.shape BEFORE", y_true.shape) + y_pred = tf.math.argmax(y_pred, axis=-1) + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + print("DEBUG:meaniou y_pred.shape AFTER", y_pred.shape) + print("DEBUG:meaniou y_true.shape AFTER", y_true.shape) + + iou.reset_state() + iou.update_state(y_true, y_pred) + return iou.result() + return one_hot_mean_iou \ No newline at end of file diff --git a/aimodel/src/lib/ai/components/MetricSensitivity.py b/aimodel/src/lib/ai/components/MetricSensitivity.py index 1839e36..74f2376 100644 --- a/aimodel/src/lib/ai/components/MetricSensitivity.py +++ b/aimodel/src/lib/ai/components/MetricSensitivity.py @@ -2,15 +2,20 @@ import math import tensorflow as tf -def sensitivity(y_true, y_pred): - print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape) - print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape) - y_pred = tf.math.argmax(y_pred, axis=-1) - y_true = tf.cast(y_true, dtype=tf.float32) - y_pred = tf.cast(y_pred, dtype=tf.float32) - print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape) - print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape) - + +def make_sensitivity(): recall = tf.keras.metrics.Recall() - recall.update_state(y_true, y_pred) - return recall.result() + def sensitivity(y_true, y_pred): + print("DEBUG:sensitivity y_pred.shape BEFORE", y_pred.shape) + print("DEBUG:sensitivity y_true.shape BEFORE", y_true.shape) + y_pred = tf.math.argmax(y_pred, axis=-1) + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + print("DEBUG:sensitivity y_pred.shape AFTER", y_pred.shape) + print("DEBUG:sensitivity y_true.shape AFTER", y_true.shape) + + recall.reset_state() + recall.update_state(y_true, y_pred) + return recall.result() + + return _sensitivity \ No newline at end of file