mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
train_mono: debug
this commit will generate a large amount of debug output.
This commit is contained in:
parent
f39e4ade70
commit
09f81b0746
3 changed files with 49 additions and 2 deletions
|
@ -2,7 +2,14 @@ import math
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class LossContrastive(tf.keras.losses.Loss):
|
class LossContrastive(tf.keras.losses.Loss):
|
||||||
|
"""Implements a contrastive loss function.
|
||||||
|
@warning: This does not function as it should.
|
||||||
|
Args:
|
||||||
|
weight_temperature (integer): The temperature weight (e.g. from LayerCheeseMultipleOut).
|
||||||
|
batch_size (integer): The batch size.
|
||||||
|
"""
|
||||||
def __init__(self, weight_temperature, batch_size):
|
def __init__(self, weight_temperature, batch_size):
|
||||||
super(LossContrastive, self).__init__()
|
super(LossContrastive, self).__init__()
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
38
aimodel/src/lib/ai/components/LossCrossentropy.py
Normal file
38
aimodel/src/lib/ai/components/LossCrossentropy.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class LossCrossentropy(tf.keras.losses.Loss):
|
||||||
|
"""Wraps the cross-entropy loss function because it's buggy.
|
||||||
|
@warning: tf.keras.losses.CategoricalCrossentropy() isn't functioning as intended during training...
|
||||||
|
Args:
|
||||||
|
batch_size (integer): The batch size (currently unused).
|
||||||
|
"""
|
||||||
|
def __init__(self, batch_size):
|
||||||
|
super(LossCrossentropy, self).__init__()
|
||||||
|
|
||||||
|
self.param_batch_size = batch_size
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
result = tf.keras.metrics.categorical_crossentropy(y_true, y_pred)
|
||||||
|
result_reduce = tf.math.reduce_sum(result)
|
||||||
|
tf.print("DEBUG:TFPRINT:loss BEFORE_REDUCE", result, "AFTER_REDUCE", result_reduce)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = super(LossCrossentropy, self).get_config()
|
||||||
|
config.update({
|
||||||
|
"batch_size": self.param_batch_size,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
weight_temperature = tf.Variable(name="loss_temperature", shape=1, initial_value=tf.constant([
|
||||||
|
math.log(1 / 0.07)
|
||||||
|
]))
|
||||||
|
loss = LossCrossentropy(weight_temperature=weight_temperature, batch_size=64)
|
||||||
|
|
||||||
|
tensor_input = tf.random.uniform([64, 2, 512])
|
||||||
|
print(loss(tf.constant(1), tensor_input))
|
|
@ -6,6 +6,7 @@ import tensorflow as tf
|
||||||
from .components.convnext import make_convnext
|
from .components.convnext import make_convnext
|
||||||
from .components.convnext_inverse import do_convnext_inverse
|
from .components.convnext_inverse import do_convnext_inverse
|
||||||
from .components.LayerStack2Image import LayerStack2Image
|
from .components.LayerStack2Image import LayerStack2Image
|
||||||
|
from .components.LossCrossentropy import LossCrossentropy
|
||||||
|
|
||||||
def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2):
|
def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2):
|
||||||
"""Makes a new rainfall / waterdepth mono model.
|
"""Makes a new rainfall / waterdepth mono model.
|
||||||
|
@ -71,7 +72,8 @@ def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
optimizer="Adam",
|
optimizer="Adam",
|
||||||
loss=tf.keras.losses.CategoricalCrossentropy(),
|
loss=LossCrossentropy(batch_size=batch_size),
|
||||||
|
# loss=tf.keras.losses.CategoricalCrossentropy(),
|
||||||
metrics=[tf.keras.metrics.CategoricalAccuracy()]
|
metrics=[tf.keras.metrics.CategoricalAccuracy()]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue