mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
mono: switch loss from crossentropy to dice
This commit is contained in:
parent
7fd7c750d6
commit
649c262960
3 changed files with 63 additions and 5 deletions
53
aimodel/src/lib/ai/components/LossDice.py
Normal file
53
aimodel/src/lib/ai/components/LossDice.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
def dice_coef(y_true, y_pred, smooth=100):
|
||||
"""Calculates the Dice coefficient.
|
||||
@source https://stackoverflow.com/a/72264322/1460422
|
||||
Args:
|
||||
y_true (Tensor): The ground truth.
|
||||
y_pred (Tensor): The predicted output.
|
||||
smooth (float, optional): The smoothness of the output. Lower values = penalise the model more for mistakes to make it better at fine detail. Defaults to 100.
|
||||
Returns:
|
||||
Tensor: The dice coefficient.
|
||||
"""
|
||||
y_true_f = tf.flatten(y_true)
|
||||
y_pred_f = K.flatten(y_pred)
|
||||
intersection = K.sum(y_true_f * y_pred_f)
|
||||
dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
|
||||
return dice
|
||||
|
||||
def dice_coef_loss(y_true, y_pred, **kwargs):
|
||||
"""Turns the dice coefficient into a loss value.
|
||||
NOTE: This is not the only option here. See also the other options in the source.
|
||||
@source https://stackoverflow.com/a/72264322/1460422
|
||||
Args:
|
||||
y_true (Tensor): The ground truth
|
||||
y_pred (Tensor): The predicted output.
|
||||
Returns:
|
||||
Tensor: The Dice coefficient, but as a loss value that decreases instead fo increases as the model learns.
|
||||
"""
|
||||
return -dice_coef(y_true, y_pred, **kwargs)
|
||||
|
||||
|
||||
|
||||
class LossDice(tf.keras.losses.Loss):
|
||||
"""An implementation of the dice loss function.
|
||||
Args:
|
||||
smooth (float): The batch size (currently unused).
|
||||
"""
|
||||
def __init__(self, smooth=100, **kwargs):
|
||||
super(LossCrossentropy, self).__init__(**kwargs)
|
||||
|
||||
self.param_smooth = smooth
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
return dice_coef_loss(y_true, y_pred, smooth=self.param_smooth)
|
||||
|
||||
def get_config(self):
|
||||
config = super(LossCrossentropy, self).get_config()
|
||||
config.update({
|
||||
"smooth": self.param_smooth,
|
||||
})
|
||||
return config
|
|
@ -69,8 +69,11 @@ def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_ar
|
|||
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
|
||||
logger.warning("Warning: TODO implement attention from https://ieeexplore.ieee.org/document/9076883")
|
||||
layer_next = tf.keras.layers.Dense(32, activation="gelu")(layer_next)
|
||||
layer_next = tf.keras.layers.Conv2D(water_bins, activation="gelu", kernel_size=1, padding="same")(layer_next)
|
||||
layer_next = tf.keras.layers.Softmax(axis=-1)(layer_next)
|
||||
# LOSS cross entropy
|
||||
# layer_next = tf.keras.layers.Conv2D(water_bins, activation="gelu", kernel_size=1, padding="same")(layer_next)
|
||||
# layer_next = tf.keras.layers.Softmax(axis=-1)(layer_next)
|
||||
# LOSS dice
|
||||
layer_next = tf.keras.layers.Conv2D(1, activation="gelu", kernel_size=1, padding="same")(layer_next)
|
||||
|
||||
model = tf.keras.Model(
|
||||
inputs = layer_input,
|
||||
|
|
|
@ -69,9 +69,11 @@ def parse_item(metadata, shape_water_desired=[100,100], water_threshold=0.1, wat
|
|||
print("DEBUG:dataset BEFORE_SQUEEZE water", water.shape)
|
||||
water = tf.squeeze(water)
|
||||
print("DEBUG:dataset AFTER_SQUEEZE water", water.shape)
|
||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||
water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
||||
|
||||
# LOSS cross entropy
|
||||
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
||||
# LOSS dice
|
||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
||||
|
||||
print("DEBUG DATASET_OUT:rainfall shape", rainfall.shape)
|
||||
print("DEBUG DATASET_OUT:water shape", water.shape)
|
||||
|
|
Loading…
Reference in a new issue