From 93e663e45d80b8d3e6ba601c0be5a9c2642b3c7a Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Wed, 11 Jan 2023 17:20:19 +0000
Subject: [PATCH] add optional PATH_CHECKPOINT env var

---
 aimodel/src/deeplabv3_plus_test_rainfall.py | 222 ++++++++++----------
 1 file changed, 114 insertions(+), 108 deletions(-)

diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py
index 5bf115f..2113c68 100755
--- a/aimodel/src/deeplabv3_plus_test_rainfall.py
+++ b/aimodel/src/deeplabv3_plus_test_rainfall.py
@@ -28,6 +28,8 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
 
 DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
 
+PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
+
 if not os.path.exists(DIR_OUTPUT):
 	os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
 
@@ -38,6 +40,7 @@ logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
 logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
 logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
 logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
+logger.info(f"> PATH_CHECKPOINT {PATH_CHECKPOINT}")
 
 
 dataset_train, dataset_validate = dataset_mono(
@@ -60,73 +63,75 @@ logger.info("Validation Dataset:", dataset_validate)
 # ██ ██ ██ ██ ██ ██ ██ ██ ██
 # ██ ██ ██████ ██████ ███████ ███████
 
-def convolution_block(
-	block_input,
-	num_filters=256,
-	kernel_size=3,
-	dilation_rate=1,
-	padding="same",
-	use_bias=False,
-):
-	x = tf.keras.layers.Conv2D(
-		num_filters,
-		kernel_size=kernel_size,
-		dilation_rate=dilation_rate,
+
+if PATH_CHECKPOINT is None:
+	def convolution_block(
+		block_input,
+		num_filters=256,
+		kernel_size=3,
+		dilation_rate=1,
 		padding="same",
-		use_bias=use_bias,
-		kernel_initializer=tf.keras.initializers.HeNormal(),
-	)(block_input)
-	x = tf.keras.layers.BatchNormalization()(x)
-	return tf.nn.relu(x)
+		use_bias=False,
+	):
+		x = tf.keras.layers.Conv2D(
+			num_filters,
+			kernel_size=kernel_size,
+			dilation_rate=dilation_rate,
+			padding="same",
+			use_bias=use_bias,
+			kernel_initializer=tf.keras.initializers.HeNormal(),
+		)(block_input)
+		x = tf.keras.layers.BatchNormalization()(x)
+		return tf.nn.relu(x)
 
-def DilatedSpatialPyramidPooling(dspp_input):
-	dims = dspp_input.shape
-	x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
-	x = convolution_block(x, kernel_size=1, use_bias=True)
-	out_pool = tf.keras.layers.UpSampling2D(
-		size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
-	)(x)
+	def DilatedSpatialPyramidPooling(dspp_input):
+		dims = dspp_input.shape
+		x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+		x = convolution_block(x, kernel_size=1, use_bias=True)
+		out_pool = tf.keras.layers.UpSampling2D(
+			size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
+		)(x)
 
-	out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
-	out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
-	out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
-	out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+		out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+		out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+		out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+		out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
 
-	x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
-	output = convolution_block(x, kernel_size=1)
-	return output
+		x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+		output = convolution_block(x, kernel_size=1)
+		return output
 
-def DeeplabV3Plus(image_size, num_classes, num_channels=3):
-	model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
-	resnet50 = tf.keras.applications.ResNet50(
-		weights="imagenet" if num_channels == 3 else None,
-		include_top=False, input_tensor=model_input
-	)
-	x = resnet50.get_layer("conv4_block6_2_relu").output
-	x = DilatedSpatialPyramidPooling(x)
+	def DeeplabV3Plus(image_size, num_classes, num_channels=3):
+		model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
+		resnet50 = tf.keras.applications.ResNet50(
+			weights="imagenet" if num_channels == 3 else None,
+			include_top=False, input_tensor=model_input
+		)
+		x = resnet50.get_layer("conv4_block6_2_relu").output
+		x = DilatedSpatialPyramidPooling(x)
 
-	input_a = tf.keras.layers.UpSampling2D(
-		size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
-		interpolation="bilinear",
-	)(x)
-	input_b = resnet50.get_layer("conv2_block3_2_relu").output
-	input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+		input_a = tf.keras.layers.UpSampling2D(
+			size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
+			interpolation="bilinear",
+		)(x)
+		input_b = resnet50.get_layer("conv2_block3_2_relu").output
+		input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
 
-	x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
-	x = convolution_block(x)
-	x = convolution_block(x)
-	x = tf.keras.layers.UpSampling2D(
-		size=(image_size // x.shape[1], image_size // x.shape[2]),
-		interpolation="bilinear",
-	)(x)
-	model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
-	return tf.keras.Model(inputs=model_input, outputs=model_output)
+		x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
+		x = convolution_block(x)
+		x = convolution_block(x)
+		x = tf.keras.layers.UpSampling2D(
+			size=(image_size // x.shape[1], image_size // x.shape[2]),
+			interpolation="bilinear",
+		)(x)
+		model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+		return tf.keras.Model(inputs=model_input, outputs=model_output)
 
-model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
-summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
+	model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
+	summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
 
@@ -137,63 +142,64 @@ summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
 # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
 # ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
 
-loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-model.compile(
-	optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
-	loss=loss,
-	metrics=["accuracy"],
-)
-logger.info(">>> Beginning training")
-history = model.fit(dataset_train,
-	validation_data=dataset_validate,
-	epochs=25,
-	callbacks=[
-		tf.keras.callbacks.CSVLogger(
-			filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
-			separator="\t"
-		),
-		CallbackCustomModelCheckpoint(
-			model_to_checkpoint=model,
-			filepath=os.path.join(
-				DIR_OUTPUT,
-				"checkpoints"
-				"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
+if PATH_CHECKPOINT is None:
+	loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+	model.compile(
+		optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+		loss=loss,
+		metrics=["accuracy"],
+	)
+	logger.info(">>> Beginning training")
+	history = model.fit(dataset_train,
+		validation_data=dataset_validate,
+		epochs=25,
+		callbacks=[
+			tf.keras.callbacks.CSVLogger(
+				filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
+				separator="\t"
 			),
-			monitor="loss"
-		),
-	],
-	steps_per_epoch=STEPS_PER_EPOCH,
-)
-logger.info(">>> Training complete")
-logger.info(">>> Plotting graphs")
+			CallbackCustomModelCheckpoint(
+				model_to_checkpoint=model,
+				filepath=os.path.join(
+					DIR_OUTPUT,
+					"checkpoints",
+					"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
+				),
+				monitor="loss"
+			),
+		],
+		steps_per_epoch=STEPS_PER_EPOCH,
+	)
+	logger.info(">>> Training complete")
+	logger.info(">>> Plotting graphs")
 
-plt.plot(history.history["loss"])
-plt.title("Training Loss")
-plt.ylabel("loss")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
-plt.close()
+	plt.plot(history.history["loss"])
+	plt.title("Training Loss")
+	plt.ylabel("loss")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
+	plt.close()
 
-plt.plot(history.history["accuracy"])
-plt.title("Training Accuracy")
-plt.ylabel("accuracy")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
-plt.close()
+	plt.plot(history.history["accuracy"])
+	plt.title("Training Accuracy")
+	plt.ylabel("accuracy")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
+	plt.close()
 
-plt.plot(history.history["val_loss"])
-plt.title("Validation Loss")
-plt.ylabel("val_loss")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
-plt.close()
+	plt.plot(history.history["val_loss"])
+	plt.title("Validation Loss")
+	plt.ylabel("val_loss")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
+	plt.close()
 
-plt.plot(history.history["val_accuracy"])
-plt.title("Validation Accuracy")
-plt.ylabel("val_accuracy")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
-plt.close()
+	plt.plot(history.history["val_accuracy"])
+	plt.title("Validation Accuracy")
+	plt.ylabel("val_accuracy")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
+	plt.close()
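
The hunks above gate the model definition, the training loop, and the plot generation behind "if PATH_CHECKPOINT is None:", so when the variable is set this excerpt never creates model at all; the matching load branch falls outside the quoted hunks. A minimal sketch of what such a branch could look like, reusing the script's existing tf, logger, and PATH_CHECKPOINT names and assuming the files written by CallbackCustomModelCheckpoint are ordinary full-model Keras .hdf5 saves (the placement and the log message here are assumptions, not part of this patch):

if PATH_CHECKPOINT is not None:
	# Restore architecture, weights, and optimizer state from the .hdf5
	# file written by the checkpoint callback during a previous run.
	model = tf.keras.models.load_model(PATH_CHECKPOINT)
	logger.info(">>> Loaded model from checkpoint, skipping training")

The script would then be invoked along the lines of PATH_CHECKPOINT=output/.../checkpoints/checkpoint_e25_loss0.123.hdf5 aimodel/src/deeplabv3_plus_test_rainfall.py (the filename is illustrative). Note that os.environ.get("PATH_CHECKPOINT") reads the variable equivalently to the conditional-expression form the patch uses, returning None when it is unset.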