add optional PATH_CHECKPOINT env var

This commit is contained in:
Starbeamrainbowlabs 2023-01-11 17:20:19 +00:00
parent 0e3de8f5fc
commit 93e663e45d
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -28,6 +28,8 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
# Optional setting: stays None when the PATH_CHECKPOINT environment variable is unset.
PATH_CHECKPOINT = os.environ.get("PATH_CHECKPOINT")
if not os.path.exists(DIR_OUTPUT): if not os.path.exists(DIR_OUTPUT):
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints")) os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
@ -38,6 +40,7 @@ logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}") logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}") logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}") logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
# Report the optional checkpoint path with the other settings (None when the variable is unset).
logger.info(f"> PATH_CHECKPOINT {PATH_CHECKPOINT}")
dataset_train, dataset_validate = dataset_mono( dataset_train, dataset_validate = dataset_mono(
@ -60,73 +63,75 @@ logger.info("Validation Dataset:", dataset_validate)
# ██ ██ ██ ██ ██ ██ ██ ██ ██ # ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████ # ██ ██ ██████ ██████ ███████ ███████
def convolution_block(
block_input, if PATH_CHECKPOINT is None:
num_filters=256, def convolution_block(
kernel_size=3, block_input,
dilation_rate=1, num_filters=256,
padding="same", kernel_size=3,
use_bias=False, dilation_rate=1,
):
x = tf.keras.layers.Conv2D(
num_filters,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
padding="same", padding="same",
use_bias=use_bias, use_bias=False,
kernel_initializer=tf.keras.initializers.HeNormal(), ):
)(block_input) x = tf.keras.layers.Conv2D(
x = tf.keras.layers.BatchNormalization()(x) num_filters,
return tf.nn.relu(x) kernel_size=kernel_size,
dilation_rate=dilation_rate,
padding="same",
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(),
)(block_input)
x = tf.keras.layers.BatchNormalization()(x)
return tf.nn.relu(x)
def DilatedSpatialPyramidPooling(dspp_input): def DilatedSpatialPyramidPooling(dspp_input):
dims = dspp_input.shape dims = dspp_input.shape
x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input) x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
x = convolution_block(x, kernel_size=1, use_bias=True) x = convolution_block(x, kernel_size=1, use_bias=True)
out_pool = tf.keras.layers.UpSampling2D( out_pool = tf.keras.layers.UpSampling2D(
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear", size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
)(x) )(x)
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1) out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6) out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12) out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18) out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18]) x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
output = convolution_block(x, kernel_size=1) output = convolution_block(x, kernel_size=1)
return output return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3): def DeeplabV3Plus(image_size, num_classes, num_channels=3):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
resnet50 = tf.keras.applications.ResNet50( resnet50 = tf.keras.applications.ResNet50(
weights="imagenet" if num_channels == 3 else None, weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=model_input include_top=False, input_tensor=model_input
) )
x = resnet50.get_layer("conv4_block6_2_relu").output x = resnet50.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x) x = DilatedSpatialPyramidPooling(x)
input_a = tf.keras.layers.UpSampling2D( input_a = tf.keras.layers.UpSampling2D(
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]), size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
interpolation="bilinear", interpolation="bilinear",
)(x) )(x)
input_b = resnet50.get_layer("conv2_block3_2_relu").output input_b = resnet50.get_layer("conv2_block3_2_relu").output
input_b = convolution_block(input_b, num_filters=48, kernel_size=1) input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b]) x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
x = convolution_block(x) x = convolution_block(x)
x = convolution_block(x) x = convolution_block(x)
x = tf.keras.layers.UpSampling2D( x = tf.keras.layers.UpSampling2D(
size=(image_size // x.shape[1], image_size // x.shape[2]), size=(image_size // x.shape[1], image_size // x.shape[2]),
interpolation="bilinear", interpolation="bilinear",
)(x) )(x)
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x) model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
return tf.keras.Model(inputs=model_input, outputs=model_output) return tf.keras.Model(inputs=model_input, outputs=model_output)
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8) model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt")) summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
@ -137,63 +142,64 @@ summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████ # ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) if PATH_CHECKPOINT is None:
model.compile( loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), model.compile(
loss=loss, optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
metrics=["accuracy"], loss=loss,
) metrics=["accuracy"],
logger.info(">>> Beginning training") )
history = model.fit(dataset_train, logger.info(">>> Beginning training")
validation_data=dataset_validate, history = model.fit(dataset_train,
epochs=25, validation_data=dataset_validate,
callbacks=[ epochs=25,
tf.keras.callbacks.CSVLogger( callbacks=[
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), tf.keras.callbacks.CSVLogger(
separator="\t" filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
), separator="\t"
CallbackCustomModelCheckpoint(
model_to_checkpoint=model,
filepath=os.path.join(
DIR_OUTPUT,
"checkpoints"
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
), ),
monitor="loss" CallbackCustomModelCheckpoint(
), model_to_checkpoint=model,
], filepath=os.path.join(
steps_per_epoch=STEPS_PER_EPOCH, DIR_OUTPUT,
) "checkpoints"
logger.info(">>> Training complete") "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
logger.info(">>> Plotting graphs") ),
monitor="loss"
),
],
steps_per_epoch=STEPS_PER_EPOCH,
)
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
plt.plot(history.history["loss"]) plt.plot(history.history["loss"])
plt.title("Training Loss") plt.title("Training Loss")
plt.ylabel("loss") plt.ylabel("loss")
plt.xlabel("epoch") plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "loss.png")) plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
plt.close() plt.close()
plt.plot(history.history["accuracy"]) plt.plot(history.history["accuracy"])
plt.title("Training Accuracy") plt.title("Training Accuracy")
plt.ylabel("accuracy") plt.ylabel("accuracy")
plt.xlabel("epoch") plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "acc.png")) plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
plt.close() plt.close()
plt.plot(history.history["val_loss"]) plt.plot(history.history["val_loss"])
plt.title("Validation Loss") plt.title("Validation Loss")
plt.ylabel("val_loss") plt.ylabel("val_loss")
plt.xlabel("epoch") plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png")) plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
plt.close() plt.close()
plt.plot(history.history["val_accuracy"]) plt.plot(history.history["val_accuracy"])
plt.title("Validation Accuracy") plt.title("Validation Accuracy")
plt.ylabel("val_accuracy") plt.ylabel("val_accuracy")
plt.xlabel("epoch") plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png")) plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
plt.close() plt.close()