From 3d051a887403356a279e7c1d91ff2220a1e586e6 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 3 Mar 2023 21:41:26 +0000 Subject: [PATCH] =?UTF-8?q?dlr:=20HACK:=20argmax=20to=20convert=20[64,128,?= =?UTF-8?q?128,=202]=20=E2=86=92=20[64,128,128]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aimodel/src/deeplabv3_plus_test_rainfall.py | 4 +-- aimodel/src/lib/ai/components/MetricDice.py | 33 ++++++--------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 5cca40b..fe9c7ef 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -18,7 +18,7 @@ import tensorflow as tf from lib.dataset.dataset_mono import dataset_mono from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice -from lib.ai.components.MetricDice import dice_coefficient +from lib.ai.components.MetricDice import metric_dice_coefficient from lib.ai.components.MetricSensitivity import sensitivity from lib.ai.components.MetricSpecificity import specificity @@ -189,7 +189,7 @@ if PATH_CHECKPOINT is None: loss=loss_fn, metrics=[ "accuracy", - dice_coefficient, + metric_dice_coefficient, tf.keras.metrics.MeanIoU(num_classes=2), sensitivity, # How many true positives were accurately predicted specificity # How many true negatives were accurately predicted? diff --git a/aimodel/src/lib/ai/components/MetricDice.py b/aimodel/src/lib/ai/components/MetricDice.py index d305e49..9ac35ec 100644 --- a/aimodel/src/lib/ai/components/MetricDice.py +++ b/aimodel/src/lib/ai/components/MetricDice.py @@ -14,32 +14,17 @@ def dice_coefficient(y_true, y_pred): Returns: tf.Tensor: The computed Dice coefficient. """ + + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + y_pred = tf.math.sigmoid(y_pred) numerator = 2 * tf.reduce_sum(y_true * y_pred) denominator = tf.reduce_sum(y_true + y_pred) - + return numerator / denominator -class MetricDice(tf.keras.metrics.Metric): - """An implementation of the dice loss function. - @source - Args: - smooth (float): The batch size (currently unused). - """ - def __init__(self, name="dice_coefficient", smooth=100, **kwargs): - super(MetricDice, self).__init__(name=name, **kwargs) - - self.param_smooth = smooth - - def call(self, y_true, y_pred): - ground_truth = tf.cast(y_true, dtype=tf.float32) - prediction = tf.cast(y_pred, dtype=tf.float32) - - return dice_coef(ground_truth, prediction, smooth=self.param_smooth) - - def get_config(self): - config = super(MetricDice, self).get_config() - config.update({ - "smooth": self.param_smooth, - }) - return config + +def metric_dice_coefficient(y_true, y_pred): + y_pred = tf.math.argmax(y_pred) + return dice_coefficient(y_true, y_pred) \ No newline at end of file