mirror of
https://github.com/sbrl/research-rainfallradar
synced 2025-01-11 06:24:56 +00:00
tvt: implement CallbackExtraValidation, which allows for a third split
it should tie into Tensorflow's logging just fine so long as it's the first callback in the queue. ***** TEST SCRIPT ***** model = tf.keras.Sequential([ tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(1) ]) model.compile(optimizer='adam', loss='mse', metrics=['mae']) X = np.random.random((100, 10)) y = np.random.random((100, 1)) split = 80 X_train, X_val = X[:split], X[split:] y_train, y_val = y[:split], y[split:] train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(10) val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(10) history = model.fit(train_dataset, epochs=10, validation_data=val_dataset, callbacks=[ CallbackExtraValidation({ "test": val_dataset }, verbose=0), tf.keras.callbacks.CSVLogger("/dev/stdout", separator="\t") ], verbose=0 ) print(f"DEBUG history {history}")
This commit is contained in:
parent
b5310304bd
commit
0761651ccf
2 changed files with 50 additions and 0 deletions
|
@ -27,6 +27,7 @@ from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coeffic
|
|||
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
||||
from lib.ai.components.MetricSpecificity import specificity
|
||||
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
||||
|
||||
time_start = datetime.now()
|
||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||
|
@ -259,6 +260,9 @@ if PATH_CHECKPOINT is None:
|
|||
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
|
||||
epochs=EPOCHS,
|
||||
callbacks=[
|
||||
CallbackExtraValidation(model, {
|
||||
"test": dataset_test # Can be None because it handles that
|
||||
}),
|
||||
tf.keras.callbacks.CSVLogger(
|
||||
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
|
||||
separator="\t"
|
||||
|
|
46
aimodel/src/lib/ai/components/CallbackExtraValidation.py
Normal file
46
aimodel/src/lib/ai/components/CallbackExtraValidation.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import tensorflow as tf
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CallbackExtraValidation(tf.keras.callbacks.Callback):
|
||||
"""
|
||||
A custom (keras) callback that to evaluate metrics on additional datasets during training.
|
||||
|
||||
These are passed back to Tensorflow/Keras by ~~abusing~~ updating the logs dictionary that's passed to us. If you update it with more metrics, then they get fed into the regular Tensorflow logging system :D
|
||||
|
||||
IMPORTANT: This MUST be the FIRST callback in the list! Otherwise it won't be executed before e.g. `tf.kkeras.callbacks.CSVLogger`.
|
||||
|
||||
TODO note to self blog about this because this was not as easy to figure out as it appears.
|
||||
|
||||
Ref kudos to <https://stackoverflow.com/a/47738812/1460422>, but you don't need to go to all that trouble :P
|
||||
|
||||
Args:
|
||||
datasets (dict): A dictionary mapping dataset names to TensorFlow Dataset
|
||||
objects.
|
||||
verbose (str, optional): The verbosity level for the dataset evaluations. Basically the same as `verbose=VALUE` on `tf.keras.Model.fit()`. Default: `"auto"`.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, verbose="auto"):
|
||||
super(CallbackExtraValidation, self).__init__()
|
||||
# self.model = model # apparently this exists by default??
|
||||
self.datasets = datasets
|
||||
self.verbose = verbose
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
if logs == None:
|
||||
logger.warning(
|
||||
"[CallbackExtraValidation] logs is None! Can't do anything here.")
|
||||
return False
|
||||
|
||||
for name, dataset in self.datasets.items():
|
||||
if dataset is None:
|
||||
logger.info(f"Skipping extra dataset {name} because it's None")
|
||||
continue
|
||||
|
||||
metrics = self.model.evaluate(
|
||||
dataset, verbose=self.verbose, return_dict=True)
|
||||
|
||||
for metric_name, metric_value in metrics.items():
|
||||
logs[f"{name}_{metric_name}"] = metric_value
|
||||
|
||||
print(metrics)
|
Loading…
Reference in a new issue