mono: switch loss from crossentropy to dice

This commit is contained in:
Starbeamrainbowlabs 2022-12-09 18:13:37 +00:00
parent 7fd7c750d6
commit 649c262960
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 63 additions and 5 deletions

View file

@ -0,0 +1,53 @@
import math
import tensorflow as tf
def dice_coef(y_true, y_pred, smooth=100):
	"""Calculates the Dice coefficient.
	@source https://stackoverflow.com/a/72264322/1460422
	Args:
		y_true (Tensor): The ground truth.
		y_pred (Tensor): The predicted output.
		smooth (float, optional): Smoothing constant added to both numerator and denominator. Lower values = penalise the model more for mistakes to make it better at fine detail. Defaults to 100.
	Returns:
		Tensor: The dice coefficient (scalar).
	"""
	# BUGFIX: the original mixed tf.flatten (which does not exist in
	# TensorFlow) with K.* (the Keras backend, never imported in this file),
	# so any call raised AttributeError/NameError. Use plain tf ops instead.
	y_true_f = tf.reshape(y_true, [-1])
	y_pred_f = tf.reshape(y_pred, [-1])
	intersection = tf.reduce_sum(y_true_f * y_pred_f)
	# dice = 2·|A∩B| / (|A| + |B|), smoothed to avoid division by zero on
	# empty masks.
	dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
	return dice
def dice_coef_loss(y_true, y_pred, **kwargs):
	"""Turns the dice coefficient into a loss value.
	The dice coefficient increases as the model improves, so it is negated
	here to give a quantity the optimiser can minimise.
	NOTE: This is not the only option here. See also the other options in the source.
	@source https://stackoverflow.com/a/72264322/1460422
	Args:
		y_true (Tensor): The ground truth.
		y_pred (Tensor): The predicted output.
	Returns:
		Tensor: The Dice coefficient, negated so that it decreases (instead of increases) as the model learns.
	"""
	return dice_coef(y_true, y_pred, **kwargs) * -1
class LossDice(tf.keras.losses.Loss):
	"""An implementation of the dice loss function.
	Args:
		smooth (float): Smoothing constant forwarded to dice_coef. Defaults to 100.
	"""
	def __init__(self, smooth=100, **kwargs):
		# BUGFIX: was super(LossCrossentropy, self) — a copy-paste leftover
		# referencing an undefined class, which raised NameError on
		# instantiation.
		super(LossDice, self).__init__(**kwargs)
		self.param_smooth = smooth
	
	def call(self, y_true, y_pred):
		return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth)
	
	def get_config(self):
		# BUGFIX: same LossCrossentropy → LossDice copy-paste fix as __init__;
		# get_config is required for model (de)serialisation to work.
		config = super(LossDice, self).get_config()
		config.update({
			"smooth": self.param_smooth,
		})
		return config

View file

@ -69,8 +69,11 @@ def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_ar
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
logger.warning("Warning: TODO implement attention from https://ieeexplore.ieee.org/document/9076883")
layer_next = tf.keras.layers.Dense(32, activation="gelu")(layer_next)
layer_next = tf.keras.layers.Conv2D(water_bins, activation="gelu", kernel_size=1, padding="same")(layer_next)
layer_next = tf.keras.layers.Softmax(axis=-1)(layer_next)
# LOSS cross entropy
# layer_next = tf.keras.layers.Conv2D(water_bins, activation="gelu", kernel_size=1, padding="same")(layer_next)
# layer_next = tf.keras.layers.Softmax(axis=-1)(layer_next)
# LOSS dice
layer_next = tf.keras.layers.Conv2D(1, activation="gelu", kernel_size=1, padding="same")(layer_next)
model = tf.keras.Model(
inputs = layer_input,

View file

@ -69,9 +69,11 @@ def parse_item(metadata, shape_water_desired=[100,100], water_threshold=0.1, wat
print("DEBUG:dataset BEFORE_SQUEEZE water", water.shape)
water = tf.squeeze(water)
print("DEBUG:dataset AFTER_SQUEEZE water", water.shape)
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
# LOSS cross entropy
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
# LOSS dice
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
print("DEBUG DATASET_OUT:rainfall shape", rainfall.shape)
print("DEBUG DATASET_OUT:water shape", water.shape)