dlr: save checkpoints

This commit is contained in:
Starbeamrainbowlabs 2023-01-09 18:03:23 +00:00
parent 52cf66ca32
commit 581006cbe6
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -5,6 +5,7 @@
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os import os
import cv2 import cv2
@ -28,7 +29,7 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
if not os.path.exists(DIR_OUTPUT): if not os.path.exists(DIR_OUTPUT):
os.makedirs(DIR_OUTPUT) os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
logger.info("DeepLabV3+ rainfall radar TEST") logger.info("DeepLabV3+ rainfall radar TEST")
logger.info(f"> BATCH_SIZE {BATCH_SIZE}") logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
@ -150,7 +151,16 @@ history = model.fit(dataset_train,
tf.keras.callbacks.CSVLogger( tf.keras.callbacks.CSVLogger(
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
separator="\t" separator="\t"
) ),
CallbackCustomModelCheckpoint(
model_to_checkpoint=model_predict,
filepath=os.path.join(
DIR_OUTPUT,
"checkpoints"
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
),
monitor="loss"
),
], ],
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
) )