diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 5cca40b..fe9c7ef 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -18,7 +18,7 @@ import tensorflow as tf from lib.dataset.dataset_mono import dataset_mono from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice -from lib.ai.components.MetricDice import dice_coefficient +from lib.ai.components.MetricDice import metric_dice_coefficient from lib.ai.components.MetricSensitivity import sensitivity from lib.ai.components.MetricSpecificity import specificity @@ -189,7 +189,7 @@ if PATH_CHECKPOINT is None: loss=loss_fn, metrics=[ "accuracy", - dice_coefficient, + metric_dice_coefficient, tf.keras.metrics.MeanIoU(num_classes=2), sensitivity, # How many true positives were accurately predicted specificity # How many true negatives were accurately predicted? diff --git a/aimodel/src/lib/ai/components/MetricDice.py b/aimodel/src/lib/ai/components/MetricDice.py index d305e49..9ac35ec 100644 --- a/aimodel/src/lib/ai/components/MetricDice.py +++ b/aimodel/src/lib/ai/components/MetricDice.py @@ -14,32 +14,17 @@ def dice_coefficient(y_true, y_pred): Returns: tf.Tensor: The computed Dice coefficient. """ + + y_true = tf.cast(y_true, dtype=tf.float32) + y_pred = tf.cast(y_pred, dtype=tf.float32) + y_pred = tf.math.sigmoid(y_pred) numerator = 2 * tf.reduce_sum(y_true * y_pred) denominator = tf.reduce_sum(y_true + y_pred) - + return numerator / denominator -class MetricDice(tf.keras.metrics.Metric): - """An implementation of the dice loss function. - @source - Args: - smooth (float): The batch size (currently unused). - """ - def __init__(self, name="dice_coefficient", smooth=100, **kwargs): - super(MetricDice, self).__init__(name=name, **kwargs) - - self.param_smooth = smooth - - def call(self, y_true, y_pred): - ground_truth = tf.cast(y_true, dtype=tf.float32) - prediction = tf.cast(y_pred, dtype=tf.float32) - - return dice_coef(ground_truth, prediction, smooth=self.param_smooth) - - def get_config(self): - config = super(MetricDice, self).get_config() - config.update({ - "smooth": self.param_smooth, - }) - return config + +def metric_dice_coefficient(y_true, y_pred): + y_pred = tf.math.argmax(y_pred) + return dice_coefficient(y_true, y_pred) \ No newline at end of file