mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-24 10:13:00 +00:00
dlr: save checkpoints
This commit is contained in:
parent
52cf66ca32
commit
581006cbe6
1 changed files with 12 additions and 2 deletions
|
@ -5,6 +5,7 @@
|
|||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from lib.ai.helpers.summarywriter import summarywriter
|
||||
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||
|
||||
import os
|
||||
import cv2
|
||||
|
@ -28,7 +29,7 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
|
|||
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||
|
||||
if not os.path.exists(DIR_OUTPUT):
|
||||
os.makedirs(DIR_OUTPUT)
|
||||
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
|
||||
|
||||
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
||||
|
@ -150,7 +151,16 @@ history = model.fit(dataset_train,
|
|||
tf.keras.callbacks.CSVLogger(
|
||||
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
|
||||
separator="\t"
|
||||
)
|
||||
),
|
||||
CallbackCustomModelCheckpoint(
|
||||
model_to_checkpoint=model_predict,
|
||||
filepath=os.path.join(
|
||||
DIR_OUTPUT,
|
||||
"checkpoints"
|
||||
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
|
||||
),
|
||||
monitor="loss"
|
||||
),
|
||||
],
|
||||
steps_per_epoch=STEPS_PER_EPOCH,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue