diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index b9f1d88..1d733cd 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -161,7 +161,7 @@ if PATH_CHECKPOINT is None: if LOSS == "cross-entropy-dice": loss_fn = LossCrossEntropyDice() elif LOSS == "cross-entropy": - tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) else: raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")