2022-08-10 18:03:25 +00:00
|
|
|
import os
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
2022-09-06 18:48:46 +00:00
|
|
|
from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
|
|
|
|
|
|
|
def make_callbacks(dirpath, model_predict):
    """Build the list of Keras callbacks used during training.

    Ensures ``<dirpath>/checkpoints`` exists (creating *dirpath* too if it is
    missing) and returns three callbacks:

    1. ``CallbackCustomModelCheckpoint`` given *model_predict* as
       ``model_to_checkpoint``, the file template
       ``<dirpath>/checkpoints/checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5``
       and ``monitor="loss"``. The ``{epoch}``/``{loss}`` placeholders are
       left unformatted here; presumably the checkpoint callback fills them
       in when it saves (confirm in CallbackCustomModelCheckpoint).
    2. ``tf.keras.callbacks.CSVLogger`` writing tab-separated metrics to
       ``<dirpath>/metrics.tsv``.
    3. ``tf.keras.callbacks.ProgbarLogger`` counting progress in steps
       (batches) rather than samples.

    Args:
        dirpath: Output directory for the checkpoints sub-directory and the
            metrics file.
        model_predict: Model handed to the checkpoint callback. Presumably
            the inference model, which may differ from the model being fit;
            verify against the caller.

    Returns:
        list: The three callbacks, in the order listed above.

    Raises:
        OSError: If the checkpoints directory cannot be created (for example,
            a regular file already occupies that path).
    """
    dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
    filepath_metrics = os.path.join(dirpath, "metrics.tsv")

    # makedirs(exist_ok=True) also creates a missing *dirpath*. It avoids the
    # race of checking exists() first and then calling mkdir().
    os.makedirs(dirpath_checkpoints, exist_ok=True)

    return [
        CallbackCustomModelCheckpoint(
            model_to_checkpoint=model_predict,
            filepath=os.path.join(
                dirpath_checkpoints,
                "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5",
            ),
            monitor="loss",
        ),
        tf.keras.callbacks.CSVLogger(
            filename=filepath_metrics,
            separator="\t",
        ),
        # count_mode="steps": the progress bar advances per batch.
        tf.keras.callbacks.ProgbarLogger(count_mode="steps"),
    ]
|