diff --git a/aimodel/src/lib/ai/components/LossCrossentropy.py b/aimodel/src/lib/ai/components/LossCrossentropy.py index 359e86f..957454e 100644 --- a/aimodel/src/lib/ai/components/LossCrossentropy.py +++ b/aimodel/src/lib/ai/components/LossCrossentropy.py @@ -9,8 +9,8 @@ class LossCrossentropy(tf.keras.losses.Loss): Args: batch_size (integer): The batch size (currently unused). """ - def __init__(self, batch_size): - super(LossCrossentropy, self).__init__() + def __init__(self, batch_size, **kwargs): + super(LossCrossentropy, self).__init__(**kwargs) self.param_batch_size = batch_size