tvt: implement CallbackExtraValidation, which allows for a third split

It should tie into TensorFlow's logging just fine, so long as it's the first callback in the queue.
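For reference, a minimal sketch of the mechanism this relies on (assumption: Keras hands the *same* logs dict object to every callback in list order, so keys written by an earlier callback are visible to later ones such as CSVLogger; the class names below are made up purely for illustration):

import tensorflow as tf

class WritesExtraMetric(tf.keras.callbacks.Callback):
	def on_epoch_end(self, epoch, logs=None):
		# Mutating logs here makes the new key visible to every callback
		# that runs *after* this one in the callbacks list.
		if logs is not None:
			logs["test_loss"] = 0.0  # placeholder value

class ReadsExtraMetric(tf.keras.callbacks.Callback):
	def on_epoch_end(self, epoch, logs=None):
		# Only sees "test_loss" if WritesExtraMetric appears earlier in the list.
		print((logs or {}).get("test_loss"))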

***** TEST SCRIPT *****

import numpy as np
import tensorflow as tf

# Import path assumes the repo layout from the diff below; adjust as needed.
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse', metrics=['mae'])
X = np.random.random((100, 10))
y = np.random.random((100, 1))

split = 80
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(10)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(10)

history = model.fit(train_dataset,
	epochs=10,
	validation_data=val_dataset,
	callbacks=[
		CallbackExtraValidation({
			"test": val_dataset
		}, verbose=0),
		tf.keras.callbacks.CSVLogger("/dev/stdout", separator="\t")
	],
	verbose=0
)

print(f"DEBUG history {history}")
Starbeamrainbowlabs 2024-08-30 18:07:17 +01:00
parent b5310304bd
commit 0761651ccf
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 50 additions and 0 deletions


@@ -27,6 +27,7 @@ from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coeffic
 from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
 from lib.ai.components.MetricSpecificity import specificity
 from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
+from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
 
 time_start = datetime.now()
 logger.info(f"Starting at {str(datetime.now().isoformat())}")
@@ -259,6 +260,9 @@ if PATH_CHECKPOINT is None:
 	# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
 	epochs=EPOCHS,
 	callbacks=[
+		CallbackExtraValidation({
+			"test": dataset_test # Can be None because it handles that
+		}),
 		tf.keras.callbacks.CSVLogger(
 			filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
 			separator="\t"


@@ -0,0 +1,46 @@
import tensorflow as tf
from loguru import logger


class CallbackExtraValidation(tf.keras.callbacks.Callback):
	"""
	A custom (Keras) callback that evaluates metrics on additional datasets during training.
	
	The results are passed back to TensorFlow/Keras by ~~abusing~~ updating the logs dictionary that's passed to us. If you update it with more metrics, then they get fed into the regular TensorFlow logging system :D
	
	IMPORTANT: This MUST be the FIRST callback in the list! Otherwise it won't be executed before e.g. `tf.keras.callbacks.CSVLogger`.
	
	TODO: note to self, blog about this, because it was not as easy to figure out as it appears.
	
	Ref kudos to <https://stackoverflow.com/a/47738812/1460422>, but you don't need to go to all that trouble :P
	
	Args:
		datasets (dict): A dictionary mapping dataset names to TensorFlow Dataset
			objects.
		verbose (str, optional): The verbosity level for the dataset evaluations. Basically the same as `verbose=VALUE` on `tf.keras.Model.fit()`. Default: `"auto"`.
	"""
	
	def __init__(self, datasets, verbose="auto"):
		super(CallbackExtraValidation, self).__init__()
		# self.model = model # apparently this exists by default??
		self.datasets = datasets
		self.verbose = verbose
	
	def on_epoch_end(self, epoch, logs=None):
		if logs is None:
			logger.warning(
				"[CallbackExtraValidation] logs is None! Can't do anything here.")
			return False
		
		for name, dataset in self.datasets.items():
			if dataset is None:
				logger.info(f"Skipping extra dataset {name} because it's None")
				continue
			
			metrics = self.model.evaluate(
				dataset, verbose=self.verbose, return_dict=True)
			
			for metric_name, metric_value in metrics.items():
				logs[f"{name}_{metric_name}"] = metric_value
			
			print(metrics)