diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 23aeef2..cbe5f83 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -5,6 +5,7 @@ from datetime import datetime from loguru import logger from lib.ai.helpers.summarywriter import summarywriter +from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint import os import cv2 @@ -28,7 +29,7 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os. DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" if not os.path.exists(DIR_OUTPUT): - os.makedirs(DIR_OUTPUT) + os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints")) logger.info("DeepLabV3+ rainfall radar TEST") logger.info(f"> BATCH_SIZE {BATCH_SIZE}") @@ -150,7 +151,16 @@ history = model.fit(dataset_train, tf.keras.callbacks.CSVLogger( filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), separator="\t" - ) + ), + CallbackCustomModelCheckpoint( + model_to_checkpoint=model_predict, + filepath=os.path.join( + DIR_OUTPUT, + "checkpoints" + "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5" + ), + monitor="loss" + ), ], steps_per_epoch=STEPS_PER_EPOCH, )