add optional PATH_CHECKPOINT env var

This commit is contained in:
Starbeamrainbowlabs 2023-01-11 17:20:19 +00:00
parent 0e3de8f5fc
commit 93e663e45d
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@@ -28,6 +28,8 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
if not os.path.exists(DIR_OUTPUT):
	os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
@@ -38,6 +40,7 @@ logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
logger.info(f"> PATH_CHECKPOINT {PATH_CHECKPOINT}")

dataset_train, dataset_validate = dataset_mono(
@@ -60,6 +63,8 @@ logger.info("Validation Dataset:", dataset_validate)
# ██  ██  ██  ██  ██  ██  ██  ██  ██
# ██  ██  ██████  ██████  ███████  ███████
if PATH_CHECKPOINT is None:
	def convolution_block(
		block_input,
		num_filters=256,
@@ -137,6 +142,7 @@ summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
# ██  ██  ██  ██  ██  ██  ██  ██  ██  ██  ██  ██  ██  ██  ██
# ██  ██  ██  ██  ██  ██  ██  ████  ██  ██  ████  ██████
if PATH_CHECKPOINT is None:
	loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
	model.compile(
		optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),