mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
actually use dice loss
This commit is contained in:
parent
649c262960
commit
e22c0981e6
1 changed files with 3 additions and 1 deletions
|
@ -7,6 +7,7 @@ from .components.convnext import make_convnext
|
|||
from .components.convnext_inverse import do_convnext_inverse
|
||||
from .components.LayerStack2Image import LayerStack2Image
|
||||
from .components.LossCrossentropy import LossCrossentropy
|
||||
from .components.LossDice import LossDice
|
||||
|
||||
def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2, learning_rate=None, heightmap_input=False):
|
||||
"""Makes a new rainfall / waterdepth mono model.
|
||||
|
@ -87,7 +88,8 @@ def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_ar
|
|||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=LossCrossentropy(batch_size=batch_size),
|
||||
# loss=LossCrossentropy(batch_size=batch_size),
|
||||
loss=LossDice(),
|
||||
# loss=tf.keras.losses.CategoricalCrossentropy(),
|
||||
metrics=[tf.keras.metrics.CategoricalAccuracy()]
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue