From b435cc54ddffd056318bc0caa20b4e215f8b032e Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 3 Mar 2023 20:00:05 +0000 Subject: [PATCH] dlr: add sensitivity (aka recall) and specificity metrics --- aimodel/src/deeplabv3_plus_test_rainfall.py | 10 ++++-- .../lib/ai/components/MetricSensitivity.py | 31 +++++++++++++++++++ .../lib/ai/components/MetricSpecificity.py | 7 ++--- 3 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 aimodel/src/lib/ai/components/MetricSensitivity.py diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 11d074d..b9d9980 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -19,6 +19,8 @@ import tensorflow as tf from lib.dataset.dataset_mono import dataset_mono from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.MetricDice import MetricDice +from lia.ai.components.MetricSensitivity import MetricSensitivity +from lib.ai.components.MetricSpecificity import MetricSpecificity time_start = datetime.now() logger.info(f"Starting at {str(datetime.now().isoformat())}") @@ -160,7 +162,9 @@ else: model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={ # Tell Tensorflow about our custom layers so that it can deserialise models that use them "LossCrossEntropyDice": LossCrossEntropyDice, - "MetricDice": MetricDice + "MetricDice": MetricDice, + "MetricSensitivity": MetricSensitivity, + "MetricSpecificity": MetricSpecificity }) @@ -186,7 +190,9 @@ if PATH_CHECKPOINT is None: metrics=[ "accuracy", MetricDice(), - tf.keras.metrics.MeanIoU(num_classes=2) + tf.keras.metrics.MeanIoU(num_classes=2), + MetricSensitivity(), # How many true positives were accurately predicted + MetricSpecificity() # How many true negatives were accurately predicted? # TODO: Add IoU, F1, Precision, Recall, here. ], ) diff --git a/aimodel/src/lib/ai/components/MetricSensitivity.py b/aimodel/src/lib/ai/components/MetricSensitivity.py new file mode 100644 index 0000000..641e0c8 --- /dev/null +++ b/aimodel/src/lib/ai/components/MetricSensitivity.py @@ -0,0 +1,31 @@ +import math + +import tensorflow as tf + + + +class MetricSensitivity(tf.keras.metrics.Metric): + """An implementation of the sensitivity. + Also known as Recall. In other words, how many of the true positives were accurately predicted. + @source + Args: + smooth (float): The batch size (currently unused). + """ + + def __init__(self, name="sensitivity", **kwargs): + super(MetricSensitivity, self).__init__(name=name) + + self.recall = tf.keras.metrics.Recall(**kwargs) + + def call(self, y_true, y_pred): + ground_truth = tf.cast(y_true, dtype=tf.float32) + prediction = tf.cast(y_pred, dtype=tf.float32) + + return self.recall(y_true, y_pred) + + def get_config(self): + config = super(MetricSensitivity, self).get_config() + config.update({ + + }) + return config diff --git a/aimodel/src/lib/ai/components/MetricSpecificity.py b/aimodel/src/lib/ai/components/MetricSpecificity.py index d0a9805..e3a2247 100644 --- a/aimodel/src/lib/ai/components/MetricSpecificity.py +++ b/aimodel/src/lib/ai/components/MetricSpecificity.py @@ -21,7 +21,8 @@ def specificity(y_pred, y_true): class MetricSpecificity(tf.keras.metrics.Metric): - """An implementation of the sensitivity. + """An implementation of the specificity. + In other words, a measure of how many of the true negatives were accurately predicted @source Args: smooth (float): The batch size (currently unused). @@ -29,8 +30,6 @@ class MetricSpecificity(tf.keras.metrics.Metric): def __init__(self, name="specificity", **kwargs): super(MetricSpecificity, self).__init__(name=name, **kwargs) - - self.param_smooth = smooth def call(self, y_true, y_pred): ground_truth = tf.cast(y_true, dtype=tf.float32) @@ -41,6 +40,6 @@ class MetricSpecificity(tf.keras.metrics.Metric): def get_config(self): config = super(MetricSpecificity, self).get_config() config.update({ - "smooth": self.param_smooth, + }) return config