From 7453c607edb28d0ca63886cd8b94fb9f71194e41 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 3 Mar 2023 21:49:33 +0000 Subject: [PATCH] argmax for sensitivity & specificity too --- aimodel/src/lib/ai/components/MetricDice.py | 6 +++--- aimodel/src/lib/ai/components/MetricSensitivity.py | 6 ++++-- aimodel/src/lib/ai/components/MetricSpecificity.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/aimodel/src/lib/ai/components/MetricDice.py b/aimodel/src/lib/ai/components/MetricDice.py index 96e19f4..f2d7abb 100644 --- a/aimodel/src/lib/ai/components/MetricDice.py +++ b/aimodel/src/lib/ai/components/MetricDice.py @@ -15,9 +15,6 @@ def dice_coefficient(y_true, y_pred): tf.Tensor: The computed Dice coefficient. """ - y_true = tf.cast(y_true, dtype=tf.float32) - y_pred = tf.cast(y_pred, dtype=tf.float32) - y_pred = tf.math.sigmoid(y_pred) numerator = 2 * tf.reduce_sum(y_true * y_pred) denominator = tf.reduce_sum(y_true + y_pred) @@ -26,5 +23,8 @@ def dice_coefficient(y_true, y_pred): def metric_dice_coefficient(y_true, y_pred): + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + y_pred = tf.math.argmax(y_pred, axis=-1) return dice_coefficient(y_true, y_pred) \ No newline at end of file diff --git a/aimodel/src/lib/ai/components/MetricSensitivity.py b/aimodel/src/lib/ai/components/MetricSensitivity.py index b04292f..4060d7c 100644 --- a/aimodel/src/lib/ai/components/MetricSensitivity.py +++ b/aimodel/src/lib/ai/components/MetricSensitivity.py @@ -3,8 +3,10 @@ import math import tensorflow as tf def sensitivity(y_true, y_pred): - ground_truth = tf.cast(y_true, dtype=tf.float32) - prediction = tf.cast(y_pred, dtype=tf.float32) + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + + y_pred = tf.math.argmax(y_pred, axis=-1) recall = tf.keras.metrics.Recall() recall.update_state(y_true, y_pred) diff --git a/aimodel/src/lib/ai/components/MetricSpecificity.py b/aimodel/src/lib/ai/components/MetricSpecificity.py index 51598c4..60b4599 100644 --- a/aimodel/src/lib/ai/components/MetricSpecificity.py +++ b/aimodel/src/lib/ai/components/MetricSpecificity.py @@ -13,10 +13,11 @@ def specificity(y_pred, y_true): Returns: Specificity score """ - y_true = tf.cast(y_true, dtype=tf.float32) y_pred = tf.cast(y_pred, dtype=tf.float32) + y_pred = tf.math.argmax(y_pred, axis=-1) + neg_y_true = 1 - y_true neg_y_pred = 1 - y_pred fp = K.sum(neg_y_true * y_pred)