diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index daa2d64..8336553 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -27,6 +27,7 @@ from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coeffic from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity from lib.ai.components.MetricSpecificity import specificity from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou +from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation time_start = datetime.now() logger.info(f"Starting at {str(datetime.now().isoformat())}") @@ -259,6 +260,9 @@ if PATH_CHECKPOINT is None: # test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way epochs=EPOCHS, callbacks=[ + CallbackExtraValidation(model, { + "test": dataset_test # Can be None because it handles that + }), tf.keras.callbacks.CSVLogger( filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), separator="\t" diff --git a/aimodel/src/lib/ai/components/CallbackExtraValidation.py b/aimodel/src/lib/ai/components/CallbackExtraValidation.py new file mode 100644 index 0000000..4ea2f07 --- /dev/null +++ b/aimodel/src/lib/ai/components/CallbackExtraValidation.py @@ -0,0 +1,46 @@ +import tensorflow as tf +from loguru import logger + + +class CallbackExtraValidation(tf.keras.callbacks.Callback): + """ + A custom (keras) callback that to evaluate metrics on additional datasets during training. + + These are passed back to Tensorflow/Keras by ~~abusing~~ updating the logs dictionary that's passed to us. If you update it with more metrics, then they get fed into the regular Tensorflow logging system :D + + IMPORTANT: This MUST be the FIRST callback in the list! Otherwise it won't be executed before e.g. `tf.kkeras.callbacks.CSVLogger`. + + TODO note to self blog about this because this was not as easy to figure out as it appears. + + Ref kudos to , but you don't need to go to all that trouble :P + + Args: + datasets (dict): A dictionary mapping dataset names to TensorFlow Dataset + objects. + verbose (str, optional): The verbosity level for the dataset evaluations. Basically the same as `verbose=VALUE` on `tf.keras.Model.fit()`. Default: `"auto"`. + """ + + def __init__(self, datasets, verbose="auto"): + super(CallbackExtraValidation, self).__init__() + # self.model = model # apparently this exists by default?? + self.datasets = datasets + self.verbose = verbose + + def on_epoch_end(self, epoch, logs=None): + if logs == None: + logger.warning( + "[CallbackExtraValidation] logs is None! Can't do anything here.") + return False + + for name, dataset in self.datasets.items(): + if dataset is None: + logger.info(f"Skipping extra dataset {name} because it's None") + continue + + metrics = self.model.evaluate( + dataset, verbose=self.verbose, return_dict=True) + + for metric_name, metric_value in metrics.items(): + logs[f"{name}_{metric_name}"] = metric_value + + print(metrics)