mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
add optional PATH_CHECKPOINT env var
This commit is contained in:
parent
0e3de8f5fc
commit
93e663e45d
1 changed file with 114 additions and 108 deletions
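This commit gates model construction, compilation, and training behind a new optional PATH_CHECKPOINT environment variable. Only the gating is visible in the hunks below (the page cuts off before any checkpoint-loading branch appears), so the following is a minimal sketch of the overall pattern: build_and_train() is a hypothetical stand-in for the gated code in the diff, and the load_model() call is an assumption about the non-None case, not code shown in this commit.

import os
import tensorflow as tf

# PATH_CHECKPOINT stays None unless the environment provides it
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None

if PATH_CHECKPOINT is None:
	model = build_and_train()  # hypothetical stand-in for the gated block below
else:
	# Assumption: a previously saved .hdf5 checkpoint is restored instead
	model = tf.keras.models.load_model(PATH_CHECKPOINT)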
@@ -28,6 +28,8 @@ STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.
 DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
 
+PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
+
 if not os.path.exists(DIR_OUTPUT):
 	os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
 
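Aside: the conditional-expression idiom above is equivalent to os.environ.get, which reads more compactly; a two-line illustration (not part of the commit, the fallback string is a placeholder):

import os

PATH_CHECKPOINT = os.environ.get("PATH_CHECKPOINT")  # None when the variable is unset
DIR_OUTPUT = os.environ.get("DIR_OUTPUT", "output/fallback")  # explicit default value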
@@ -38,6 +40,7 @@ logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
 logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
 logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
 logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
+logger.info(f"> PATH_CHECKPOINT {PATH_CHECKPOINT}")
 
 dataset_train, dataset_validate = dataset_mono(
@@ -60,73 +63,75 @@ logger.info("Validation Dataset:", dataset_validate)
 # ██  ██  ██ ██    ██ ██   ██ ██      ██
 # ██      ██  ██████  ██████  ███████ ███████
 
+if PATH_CHECKPOINT is None:
-def convolution_block(
-	block_input,
-	num_filters=256,
-	kernel_size=3,
-	dilation_rate=1,
-	padding="same",
-	use_bias=False,
-):
-	x = tf.keras.layers.Conv2D(
-		num_filters,
-		kernel_size=kernel_size,
-		dilation_rate=dilation_rate,
-		padding="same",
-		use_bias=use_bias,
-		kernel_initializer=tf.keras.initializers.HeNormal(),
-	)(block_input)
-	x = tf.keras.layers.BatchNormalization()(x)
-	return tf.nn.relu(x)
+	def convolution_block(
+		block_input,
+		num_filters=256,
+		kernel_size=3,
+		dilation_rate=1,
+		padding="same",
+		use_bias=False,
+	):
+		x = tf.keras.layers.Conv2D(
+			num_filters,
+			kernel_size=kernel_size,
+			dilation_rate=dilation_rate,
+			padding="same",
+			use_bias=use_bias,
+			kernel_initializer=tf.keras.initializers.HeNormal(),
+		)(block_input)
+		x = tf.keras.layers.BatchNormalization()(x)
+		return tf.nn.relu(x)
 
 
-def DilatedSpatialPyramidPooling(dspp_input):
-	dims = dspp_input.shape
-	x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
-	x = convolution_block(x, kernel_size=1, use_bias=True)
-	out_pool = tf.keras.layers.UpSampling2D(
-		size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
-	)(x)
-
-	out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
-	out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
-	out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
-	out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
-
-	x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
-	output = convolution_block(x, kernel_size=1)
-	return output
+	def DilatedSpatialPyramidPooling(dspp_input):
+		dims = dspp_input.shape
+		x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+		x = convolution_block(x, kernel_size=1, use_bias=True)
+		out_pool = tf.keras.layers.UpSampling2D(
+			size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
+		)(x)
+
+		out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+		out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+		out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+		out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+		x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+		output = convolution_block(x, kernel_size=1)
+		return output
 
 
-def DeeplabV3Plus(image_size, num_classes, num_channels=3):
-	model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
-	resnet50 = tf.keras.applications.ResNet50(
-		weights="imagenet" if num_channels == 3 else None,
-		include_top=False, input_tensor=model_input
-	)
-	x = resnet50.get_layer("conv4_block6_2_relu").output
-	x = DilatedSpatialPyramidPooling(x)
-
-	input_a = tf.keras.layers.UpSampling2D(
-		size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
-		interpolation="bilinear",
-	)(x)
-	input_b = resnet50.get_layer("conv2_block3_2_relu").output
-	input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
-
-	x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
-	x = convolution_block(x)
-	x = convolution_block(x)
-	x = tf.keras.layers.UpSampling2D(
-		size=(image_size // x.shape[1], image_size // x.shape[2]),
-		interpolation="bilinear",
-	)(x)
-	model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
-	return tf.keras.Model(inputs=model_input, outputs=model_output)
+	def DeeplabV3Plus(image_size, num_classes, num_channels=3):
+		model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
+		resnet50 = tf.keras.applications.ResNet50(
+			weights="imagenet" if num_channels == 3 else None,
+			include_top=False, input_tensor=model_input
+		)
+		x = resnet50.get_layer("conv4_block6_2_relu").output
+		x = DilatedSpatialPyramidPooling(x)
+
+		input_a = tf.keras.layers.UpSampling2D(
+			size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
+			interpolation="bilinear",
+		)(x)
+		input_b = resnet50.get_layer("conv2_block3_2_relu").output
+		input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+		x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
+		x = convolution_block(x)
+		x = convolution_block(x)
+		x = tf.keras.layers.UpSampling2D(
+			size=(image_size // x.shape[1], image_size // x.shape[2]),
+			interpolation="bilinear",
+		)(x)
+		model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+		return tf.keras.Model(inputs=model_input, outputs=model_output)
 
 
-model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
-summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
+	model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
+	summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
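The gated block above is a standard DeepLabV3+ assembly: a ResNet50 backbone tapped at conv4_block6_2_relu (1/16 resolution), dilated spatial pyramid pooling, a low-level skip from conv2_block3_2_relu (1/4 resolution), and bilinear upsampling back to the input grid. A hypothetical smoke test, assuming the definitions above are in scope and an image size that survives the integer divisions (e.g. 128):

# Not part of the commit: quick shape check of the assembled network.
model = DeeplabV3Plus(image_size=128, num_classes=2, num_channels=8)
# ResNet50 reduces 128 -> 8 at the DSPP tap; the two bilinear UpSampling2D
# stages restore the full 128x128 grid, one logit map per class.
print(model.output_shape)  # (None, 128, 128, 2)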
@@ -137,63 +142,64 @@ summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
 #    ██    ██   ██ ██   ██ ██ ██  ██ ██ ██ ██  ██ ██ ██    ██
 #    ██    ██   ██ ██   ██ ██ ██   ████ ██ ██   ████  ██████
 
+if PATH_CHECKPOINT is None:
-loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-model.compile(
-	optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
-	loss=loss,
-	metrics=["accuracy"],
-)
-logger.info(">>> Beginning training")
-history = model.fit(dataset_train,
-	validation_data=dataset_validate,
-	epochs=25,
-	callbacks=[
-		tf.keras.callbacks.CSVLogger(
-			filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
-			separator="\t"
-		),
-		CallbackCustomModelCheckpoint(
-			model_to_checkpoint=model,
-			filepath=os.path.join(
-				DIR_OUTPUT,
-				"checkpoints"
-				"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
-			),
-			monitor="loss"
-		),
-	],
-	steps_per_epoch=STEPS_PER_EPOCH,
-)
-logger.info(">>> Training complete")
-logger.info(">>> Plotting graphs")
+	loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+	model.compile(
+		optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+		loss=loss,
+		metrics=["accuracy"],
+	)
+	logger.info(">>> Beginning training")
+	history = model.fit(dataset_train,
+		validation_data=dataset_validate,
+		epochs=25,
+		callbacks=[
+			tf.keras.callbacks.CSVLogger(
+				filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
+				separator="\t"
+			),
+			CallbackCustomModelCheckpoint(
+				model_to_checkpoint=model,
+				filepath=os.path.join(
+					DIR_OUTPUT,
+					"checkpoints"
+					"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
+				),
+				monitor="loss"
+			),
+		],
+		steps_per_epoch=STEPS_PER_EPOCH,
+	)
+	logger.info(">>> Training complete")
+	logger.info(">>> Plotting graphs")
 
-plt.plot(history.history["loss"])
-plt.title("Training Loss")
-plt.ylabel("loss")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
-plt.close()
+	plt.plot(history.history["loss"])
+	plt.title("Training Loss")
+	plt.ylabel("loss")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
+	plt.close()
 
-plt.plot(history.history["accuracy"])
-plt.title("Training Accuracy")
-plt.ylabel("accuracy")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
-plt.close()
+	plt.plot(history.history["accuracy"])
+	plt.title("Training Accuracy")
+	plt.ylabel("accuracy")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
+	plt.close()
 
-plt.plot(history.history["val_loss"])
-plt.title("Validation Loss")
-plt.ylabel("val_loss")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
-plt.close()
+	plt.plot(history.history["val_loss"])
+	plt.title("Validation Loss")
+	plt.ylabel("val_loss")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
+	plt.close()
 
-plt.plot(history.history["val_accuracy"])
-plt.title("Validation Accuracy")
-plt.ylabel("val_accuracy")
-plt.xlabel("epoch")
-plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
-plt.close()
+	plt.plot(history.history["val_accuracy"])
+	plt.title("Validation Accuracy")
+	plt.ylabel("val_accuracy")
+	plt.xlabel("epoch")
+	plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
+	plt.close()
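Two notes on the training block above. First, the filepath template passes "checkpoints" and "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5" as adjacent string literals with no comma between them, so Python concatenates them into a single path component ("checkpointscheckpoint_e…") rather than placing files inside the checkpoints/ directory created earlier; a comma between the two literals is likely the intent. Second, CallbackCustomModelCheckpoint is a repository-local class whose implementation does not appear in this diff; a minimal sketch of what such a callback plausibly does, as an assumption rather than the repository's actual code:

import tensorflow as tf

# Illustrative reimplementation (assumption): save a specific model at the
# end of every epoch, filling the filename template from the training logs.
class CallbackCustomModelCheckpoint(tf.keras.callbacks.Callback):
	def __init__(self, model_to_checkpoint, filepath, monitor="loss"):
		super().__init__()
		self.model_to_checkpoint = model_to_checkpoint
		self.filepath = filepath
		self.monitor = monitor  # kept for parity with the call site; unused in this sketch
	
	def on_epoch_end(self, epoch, logs=None):
		logs = logs or {}
		# e.g. "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5" -> "checkpoint_e1_loss0.123.hdf5"
		self.model_to_checkpoint.save(self.filepath.format(epoch=epoch + 1, **logs))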