research-rainfallradar/aimodel/src/deeplabv3_plus_test_rainfall.py


#!/usr/bin/env python3
# @source https://keras.io/examples/vision/deeplabv3_plus/
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import math
import cv2
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
import tensorflow as tf
from lib.dataset.dataset_mono import dataset_mono
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}")
# ███████ ███ ██ ██ ██ ██ ██████ ██████ ███ ██ ███ ███ ███████ ███ ██ ████████
# ██ ████ ██ ██ ██ ██ ██ ██ ██ ██ ████ ██ ████ ████ ██ ████ ██ ██
# █████ ██ ██ ██ ██ ██ ██ ██████ ██ ██ ██ ██ ██ ██ ████ ██ █████ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
NUM_CLASSES = 2
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
PATH_COLOURMAP = os.environ["PATH_COLOURMAP"]
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
REMOVE_ISOLATED_PIXELS = "NO_REMOVE_ISOLATED_PIXELS" not in os.environ
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 50
LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice"
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
DIR_OUTPUT = os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
PREDICT_COUNT = int(os.environ["PREDICT_COUNT"]) if "PREDICT_COUNT" in os.environ else 25
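
# Example invocation (a sketch: the paths below are hypothetical placeholders,
# and everything except the three required variables falls back to the
# defaults above):
#
#   DIR_RAINFALLWATER=data/rainfallwater \
#   PATH_HEIGHTMAP=data/heightmap.tiff \
#   PATH_COLOURMAP=data/colourmap.mat \
#   EPOCHS=50 BATCH_SIZE=64 src/deeplabv3_plus_test_rainfall.py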
# ~~~
if not os.path.exists(DIR_OUTPUT):
    os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
# ~~~
logger.info("DeepLabV3+ rainfall radar TEST")
for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT" ]:
logger.info(f"> {env_name} {str(globals()[env_name])}")
2023-03-01 16:47:36 +00:00
# ██████ █████ ████████ █████ ███████ ███████ ████████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ███████ ██ ███████ ███████ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
dataset_train, dataset_validate = dataset_mono(
    dirpath_input=DIR_RAINFALLWATER,
    batch_size=BATCH_SIZE,
    water_threshold=0.1,
    rainfall_scale_up=2, # done BEFORE cropping to the size below
    output_size=IMAGE_SIZE,
    input_size="same",
    filepath_heightmap=PATH_HEIGHTMAP,
    do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS
)

logger.info(f"Train Dataset: {dataset_train}")
logger.info(f"Validation Dataset: {dataset_validate}")

# ███ ███ ██████ ██████ ███████ ██
# ████ ████ ██ ██ ██ ██ ██ ██
# ██ ████ ██ ██ ██ ██ ██ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████
if PATH_CHECKPOINT is None:
    def convolution_block(
        block_input,
        num_filters=256,
        kernel_size=3,
        dilation_rate=1,
        padding="same",
        use_bias=False,
    ):
        # Conv → BatchNorm → ReLU, the basic building block used throughout
        x = tf.keras.layers.Conv2D(
            num_filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding=padding,
            use_bias=use_bias,
            kernel_initializer=tf.keras.initializers.HeNormal(),
        )(block_input)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.nn.relu(x)
    
    def DilatedSpatialPyramidPooling(dspp_input):
        # Image-level pooling branch
        dims = dspp_input.shape
        x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
        x = convolution_block(x, kernel_size=1, use_bias=True)
        out_pool = tf.keras.layers.UpSampling2D(
            size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
        )(x)
        
        # Parallel atrous convolution branches at increasing dilation rates
        out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
        out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
        out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
        out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
        
        x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
        output = convolution_block(x, kernel_size=1)
        return output
    
    def DeeplabV3Plus(image_size, num_classes, num_channels=3):
        model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
        x = tf.keras.layers.UpSampling2D(size=2)(model_input)
        resnet50 = tf.keras.applications.ResNet50(
            weights="imagenet" if num_channels == 3 else None, # pretrained weights only fit 3-channel inputs
            include_top=False, input_tensor=x
        )
        x = resnet50.get_layer("conv4_block6_2_relu").output
        x = DilatedSpatialPyramidPooling(x)
        
        input_a = tf.keras.layers.UpSampling2D(
            size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
            interpolation="bilinear",
        )(x)
        input_b = resnet50.get_layer("conv2_block3_2_relu").output
        input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
        
        x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
        x = convolution_block(x)
        x = convolution_block(x)
        x = tf.keras.layers.UpSampling2D(
            size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
            interpolation="bilinear",
        )(x)
        model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
        return tf.keras.Model(inputs=model_input, outputs=model_output)
    
    model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
    summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
else:
    model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
        # Tell TensorFlow about our custom layers so that it can deserialise models that use them
        "LossCrossEntropyDice": LossCrossEntropyDice,
        "metric_dice_coefficient": dice_coefficient,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "one_hot_mean_iou": mean_iou
    })
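
# Note: the final Conv2D layer of the model has no activation, so it outputs raw
# per-class logits. The cross-entropy loss below is accordingly constructed with
# from_logits=True, and predictions are argmax-ed before being rendered.
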
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
if PATH_CHECKPOINT is None:
    loss_fn = None
    if LOSS == "cross-entropy-dice":
        loss_fn = LossCrossEntropyDice()
    elif LOSS == "cross-entropy":
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    else:
        raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=loss_fn,
        metrics=[
            "accuracy",
            dice_coefficient,
            mean_iou(),
            sensitivity(), # How many true positives were accurately predicted?
            specificity # How many true negatives were accurately predicted?
            # TODO: Add F1 and precision here.
        ],
    )
    
    logger.info(">>> Beginning training")
    history = model.fit(dataset_train,
        validation_data=dataset_validate,
        epochs=EPOCHS,
        callbacks=[
            tf.keras.callbacks.CSVLogger(
                filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
                separator="\t"
            ),
            CallbackCustomModelCheckpoint(
                model_to_checkpoint=model,
                filepath=os.path.join(
                    DIR_OUTPUT,
                    "checkpoints",
                    "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
                ),
                monitor="loss"
            ),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
    )
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
plt.plot(history.history["loss"])
plt.title("Training Loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
plt.close()
plt.plot(history.history["accuracy"])
plt.title("Training Accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
plt.close()
plt.plot(history.history["val_loss"])
plt.title("Validation Loss")
plt.ylabel("val_loss")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
plt.close()
plt.plot(history.history["val_accuracy"])
plt.title("Validation Accuracy")
plt.ylabel("val_accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
plt.close()
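
# The same per-epoch metrics also land in metrics.tsv via the CSVLogger callback
# above. A quick way to inspect them afterwards (a sketch, assuming pandas is
# available; it is not imported by this script):
#
#   import pandas as pd
#   df = pd.read_csv("metrics.tsv", sep="\t")
#   df.plot(x="epoch", y=["loss", "val_loss"])
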
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██
# ██ ██ ██ ██ █████ █████ ██████ █████ ██ ██ ██ ██ █████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ████ ██ ███████ ██ ██ ███████ ██ ████ ██████ ███████
# Load the colourmap used to render the segmentation masks
colormap = loadmat(PATH_COLOURMAP)["colormap"]
colormap = colormap * 100
colormap = colormap.astype(np.uint8)

def infer(model, image_tensor):
    predictions = model.predict(tf.expand_dims(image_tensor, axis=0))
    predictions = tf.squeeze(predictions)
    return predictions

def decode_segmentation_masks(mask, colormap, n_classes):
    r = np.zeros_like(mask).astype(np.uint8)
    g = np.zeros_like(mask).astype(np.uint8)
    b = np.zeros_like(mask).astype(np.uint8)
    for l in range(0, n_classes):
        idx = mask == l
        r[idx] = colormap[l, 0]
        g[idx] = colormap[l, 1]
        b[idx] = colormap[l, 2]
    rgb = np.stack([r, g, b], axis=2)
    return rgb
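
# Note: assuming every value in mask is a valid class index (< n_classes), the
# per-channel loop above is equivalent to a single fancy-indexing lookup:
#   rgb = colormap[np.asarray(mask)]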

def get_overlay(image, coloured_mask):
    image = tf.keras.preprocessing.image.array_to_img(image)
    image = np.array(image).astype(np.uint8)
    overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0)
    return overlay

def plot_samples_matplotlib(filepath, display_list):
    plt.figure(figsize=(16, 8))
    for i in range(len(display_list)):
        plt.subplot(2, math.ceil(len(display_list) / 2), i+1)
        if display_list[i].shape[-1] == 3:
            plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        else:
            plt.imshow(display_list[i])
            plt.colorbar()
    plt.savefig(filepath, dpi=200)
    plt.close() # free the figure, since we render many of these in one run

def plot_predictions(filepath, input_items, colormap, model):
    for i, input_pair in enumerate(input_items):
        prediction_mask = infer(image_tensor=input_pair[0], model=model)
        prediction_mask_argmax = tf.argmax(prediction_mask, axis=2)
        # label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
        prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2)
        
        # print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
        
        plot_samples_matplotlib(
            filepath.replace("$$", str(i)),
            [
                # input_tensor,
                tf.math.reduce_max(input_pair[0][:,:,:-1], axis=-1), # rainfall only
                input_pair[0][:,:,-1], # heightmap
                input_pair[1], # label_colourmap,
                prediction_mask[:,:,1],
                prediction_colormap
            ]
        )

def get_from_batched(dataset, count):
    result = []
    for batched in dataset:
        items_input = tf.unstack(batched[0], axis=0)
        items_label = tf.unstack(batched[1], axis=0)
        for item in zip(items_input, items_label):
            result.append(item)
            if len(result) >= count:
                return result
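
# get_from_batched unstacks each batch back into individual (input, label) pairs
# and returns the first `count` of them, so plot_predictions can render one
# sample per figure.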

plot_predictions(
    os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
    get_from_batched(dataset_train, PREDICT_COUNT),
    colormap,
    model=model
)
plot_predictions(
    os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
    get_from_batched(dataset_validate, PREDICT_COUNT),
    colormap,
    model=model
)
logger.info(f"Complete at {str(datetime.now().isoformat())}, elapsed {str((datetime.now() - time_start).total_seconds())} seconds")