research-rainfallradar/aimodel/src/deeplabv3_plus_test_rainfall.py

410 lines
16 KiB
Python
Executable file

#!/usr/bin/env python3
# @source https://keras.io/examples/vision/deeplabv3_plus/
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import io
import math
import json
import cv2
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
import tensorflow as tf
import lib.primitives.env
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
# Wall-clock reference captured at startup; the final log line reports the elapsed time against it.
time_start = datetime.now()
logger.info(f"Starting at {datetime.now().isoformat()}")
# ███████ ███ ██ ██ ██ ██ ██████ ██████ ███ ██ ███ ███ ███████ ███ ██ ████████
# ██ ████ ██ ██ ██ ██ ██ ██ ██ ██ ████ ██ ████ ████ ██ ████ ██ ██
# █████ ██ ██ ██ ██ ██ ██ ██████ ██ ██ ██ ██ ██ ██ ████ ██ █████ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
# Runtime configuration. Every value goes through env.read(name, type[, default]).
# NOTE(review): the only visible import is `import lib.primitives.env`, which binds `lib`, not `env` —
# confirm that an `env` alias is in scope, otherwise these lines raise NameError.
# NOTE(review): presumably env.read pulls from environment variables and a missing default makes the
# value mandatory — confirm against lib.primitives.env.
IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
# Not configurable: passed to the model as the number of output channels.
NUM_CLASSES = 2
DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)
PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
# NOTE(review): the variable read is NO_REMOVE_ISOLATED_PIXELS but the constant is REMOVE_ISOLATED_PIXELS
# with a default of True — confirm the polarity is what is intended.
REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)
EPOCHS = env.read("EPOCHS", int, 50)
LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy
DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
UPSAMPLE = env.read("UPSAMPLE", int, 2)
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)
STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
JIT_COMPILE = env.read("JIT_COMPILE", bool, False)
# Default output directory is stamped with today's UTC date.
DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
# When PATH_CHECKPOINT is None a new model is built and trained; otherwise the checkpoint is loaded.
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)
# Number of samples taken from each dataset for the prediction plots.
PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
# When true, the input is loaded with dataset_mono_predict as a single dataset (no train/validate/test split).
PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)
# ~~~
# Checkpoints are written to DIR_OUTPUT/checkpoints during training.
# NOTE(review): presumably create=True makes the directory if it is missing — confirm.
env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)
# ~~~
logger.info("DeepLabV3+ rainfall radar TEST")
env.print_all(False)
# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
# logger.info(f"> {env_name} {str(globals()[env_name])}")
# ██████ █████ ████████ █████ ███████ ███████ ████████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ███████ ██ ███████ ███████ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
if not PREDICT_AS_ONE:
    # Normal mode: load the data as separate train / validate / test datasets.
    dataset_train, dataset_validate, dataset_test = dataset_mono(
        dirpath_input=DIR_RAINFALLWATER,
        batch_size=BATCH_SIZE,
        water_threshold=WATER_THRESHOLD,
        rainfall_scale_up=2,  # done BEFORE cropping to the below size
        output_size=IMAGE_SIZE,
        input_size="same",
        filepath_heightmap=PATH_HEIGHTMAP,
        do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS,
        parallel_reads_multiplier=PARALLEL_READS,
        percentage_validate=SPLIT_VALIDATE,
        percentage_test=SPLIT_TEST,  # was the undefined name `SPLIT_TESTs`
    )

    # loguru only interpolates positional arguments into `{}` placeholders;
    # without one the dataset description is silently dropped from the log.
    logger.info("Train Dataset: {}", dataset_train)
    logger.info("Validation Dataset: {}", dataset_validate)
    logger.info("Test Dataset: {}", dataset_test)
else:
    # AS_ONE mode: a single dataset, stored under the name dataset_train.
    # dataset_validate and dataset_test are NOT defined in this mode.
    dataset_train = dataset_mono_predict(
        dirpath_input=DIR_RAINFALLWATER,
        batch_size=BATCH_SIZE,
        water_threshold=WATER_THRESHOLD,
        rainfall_scale_up=2,  # done BEFORE cropping to the below size
        output_size=IMAGE_SIZE,
        input_size="same",
        filepath_heightmap=PATH_HEIGHTMAP,
        do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS,
    )
    logger.info("Dataset AS_ONE: {}", dataset_train)
# ███ ███ ██████ ██████ ███████ ██
# ████ ████ ██ ██ ██ ██ ██ ██
# ██ ████ ██ ██ ██ ██ ██ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████
if PATH_CHECKPOINT is None:
def convolution_block(
block_input,
num_filters=256,
kernel_size=3,
dilation_rate=1,
padding="same",
use_bias=False,
):
x = tf.keras.layers.Conv2D(
num_filters,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
padding="same",
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(),
)(block_input)
x = tf.keras.layers.BatchNormalization()(x)
return tf.nn.relu(x)
def DilatedSpatialPyramidPooling(dspp_input):
dims = dspp_input.shape
x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
x = convolution_block(x, kernel_size=1, use_bias=True)
out_pool = tf.keras.layers.UpSampling2D(
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
)(x)
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
output = convolution_block(x, kernel_size=1)
return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet", upsample=2):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
if upsample > 1:
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
else:
logger.info(f"[DeepLabV3+] Upsample disabled")
x = model_input
match backbone:
case "resnet":
backbone = tf.keras.applications.ResNet50(
weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=x
)
case _:
raise Exception(f"Error: Unknown backbone {backbone}")
x = backbone.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x)
factor = 4 if upsample == 2 else 8 # else: upsample == 1. other values are not supported yet because maths
input_a = tf.keras.layers.UpSampling2D(
size=(image_size // factor // x.shape[1] * 2, image_size // factor // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
interpolation="bilinear",
)(x)
input_b = backbone.get_layer("conv2_block3_2_relu").output
input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
x = convolution_block(x)
x = convolution_block(x)
x = tf.keras.layers.UpSampling2D(
size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
interpolation="bilinear",
)(x)
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
return tf.keras.Model(inputs=model_input, outputs=model_output)
model = DeeplabV3Plus(
image_size=IMAGE_SIZE,
num_classes=NUM_CLASSES,
upsample=UPSAMPLE,
num_channels=8
)
summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
else:
model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
# Tell Tensorflow about our custom layers so that it can deserialise models that use them
"LossCrossEntropyDice": LossCrossEntropyDice,
"metric_dice_coefficient": dice_coefficient,
"sensitivity": sensitivity,
"specificity": specificity,
"one_hot_mean_iou": mean_iou
})
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
def plot_metric(train, val, name, dir_output):
    """Plot a training metric against its validation counterpart and save it.

    Args:
        train: Per-epoch values of the metric on the training data.
        val: Per-epoch values of the metric on the validation data.
        name: Metric name; used for the title, the y-axis label, the legend
            entries and the output filename (`<name>.png`).
        dir_output: Directory the PNG is written to.
    """
    plt.plot(train, label=f"train_{name}")
    plt.plot(val, label=f"val_{name}")
    plt.title(name)
    plt.xlabel("epoch")
    plt.ylabel(name)
    # The curves carry labels, but they are only drawn once a legend is requested.
    plt.legend()
    plt.savefig(os.path.join(dir_output, f"{name}.png"))
    plt.close()
if PATH_CHECKPOINT is None:
loss_fn = None
if LOSS == "cross-entropy-dice":
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
elif LOSS == "cross-entropy":
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
else:
raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
loss=loss_fn,
metrics=[
"accuracy",
dice_coefficient,
mean_iou(),
sensitivity(), # How many true positives were accurately predicted
specificity # How many true negatives were accurately predicted?
# TODO: Add IoU, F1, Precision, Recall, here.
],
steps_per_execution=STEPS_PER_EXECUTION,
jit_compile=JIT_COMPILE
)
logger.info(">>> Beginning training")
history = model.fit(dataset_train,
validation_data=dataset_validate,
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
epochs=EPOCHS,
callbacks=[
tf.keras.callbacks.CSVLogger(
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
separator="\t"
),
CallbackCustomModelCheckpoint(
model_to_checkpoint=model,
filepath=os.path.join(
DIR_OUTPUT,
"checkpoints",
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
),
monitor="loss"
),
],
steps_per_epoch=STEPS_PER_EPOCH,
)
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██
# ██ ██ ██ ██ █████ █████ ██████ █████ ██ ██ ██ ██ █████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ████ ██ ███████ ██ ██ ███████ ██ ████ ██████ ███████
# Loading the Colormap
# Read the "colormap" entry from the .mat file, scale it by 100 and store it as 8-bit integers.
colormap = (loadmat(PATH_COLOURMAP)["colormap"] * 100).astype(np.uint8)
def infer(model, image_tensor, do_argmax=True):
    """Run the model on a single sample and return its squeezed prediction.

    A batch axis is added in front of image_tensor before calling
    model.predict(), and size-1 axes are squeezed out of the result.
    `do_argmax` is accepted but not used.
    """
    batch_of_one = tf.expand_dims(image_tensor, axis=0)
    return tf.squeeze(model.predict(batch_of_one))
def decode_segmentation_masks(mask, colormap, n_classes):
    """Turn a class-index mask into a uint8 colour image.

    Every pixel whose value is a class id in range(n_classes) receives the
    first three entries of colormap[class_id] as its channels; any other
    pixel stays 0. The channels are stacked on axis 2.
    """
    blank = np.zeros_like(mask).astype(np.uint8)
    planes = [blank.copy(), blank.copy(), blank.copy()]
    for class_id in range(n_classes):
        selected = mask == class_id
        for channel, plane in enumerate(planes):
            plane[selected] = colormap[class_id, channel]
    return np.stack(planes, axis=2)
def get_overlay(image, coloured_mask):
    """Blend coloured_mask over image at a 35% image / 65% mask weighting."""
    as_pil = tf.keras.preprocessing.image.array_to_img(image)
    image_u8 = np.array(as_pil).astype(np.uint8)
    return cv2.addWeighted(image_u8, 0.35, coloured_mask, 0.65, 0)
def plot_samples_matplotlib(filepath, display_list):
    """Render the given images as a two-row grid and save it to filepath.

    Args:
        filepath: Destination image path.
        display_list: Images to draw. Items whose last axis has length 3 are
            converted with array_to_img; everything else is drawn directly.
    """
    figure = plt.figure(figsize=(16, 8))
    columns = math.ceil(len(display_list) / 2)
    for index, item in enumerate(display_list):
        plt.subplot(2, columns, index + 1)
        if item.shape[-1] == 3:
            plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        else:
            plt.imshow(item)
            # NOTE(review): colour bar attached to the non-3-channel panels only — confirm
            # this matches the intended nesting.
            plt.colorbar()
    plt.savefig(filepath, dpi=200)
    # This runs once per predicted sample; without closing, every figure stays
    # alive in pyplot's registry and memory grows with each call.
    plt.close(figure)
def save_samples(filepath, save_list):
    """Append save_list to filepath as one line of JSON (JSON Lines).

    Args:
        filepath: File to append to; created if it does not exist.
        save_list: JSON-serialisable value to write.
    """
    # The context manager closes the handle even if serialisation raises.
    with open(filepath, "a") as handle:
        json.dump(save_list, handle)
        handle.write("\n")
def plot_predictions(filepath, input_items, colormap, model):
    """Predict every item, saving one plot per item and the raw outputs as JSON Lines.

    Args:
        filepath: Plot path template; "$$" is replaced with the item index.
            The JSON Lines path is derived from it by dropping "_$$" and
            swapping ".png" for ".jsonl".
        input_items: Iterable of (input, label) pairs.
        colormap: Colour lookup table passed to decode_segmentation_masks.
        model: Model used for prediction.
    """
    filepath_jsonl = filepath.replace("_$$", "").replace(".png", ".jsonl")
    # save_samples() appends, so clear out anything left over from a previous run.
    if os.path.exists(filepath_jsonl):
        os.truncate(filepath_jsonl, 0)

    for index, (sample_input, sample_label) in enumerate(input_items):
        prediction_mask = infer(image_tensor=sample_input, model=model)
        prediction_classes = tf.argmax(prediction_mask, axis=2)
        prediction_colormap = decode_segmentation_masks(prediction_classes, colormap, 2)

        panels = [
            tf.math.reduce_max(sample_input[:,:,:-1], axis=-1),  # rainfall only
            sample_input[:,:,-1],  # heightmap
            sample_label,
            prediction_mask[:,:,1],
            prediction_colormap,
        ]
        plot_samples_matplotlib(filepath.replace("$$", str(index)), panels)

        save_samples(filepath_jsonl, prediction_mask.numpy().tolist())
def get_from_batched(dataset, count):
    """Collect (input, label) pairs from a batched dataset.

    Each batch is unstacked along axis 0 and its inputs and labels are paired
    up one by one.

    Args:
        dataset: Iterable of batches, where batch[0] holds the inputs and
            batch[1] the labels.
        count: Stop as soon as this many pairs have been collected.

    Returns:
        A list of (input, label) tuples. It holds fewer than count items when
        the dataset runs out first.
    """
    result = []
    for batched in dataset:
        items_input = tf.unstack(batched[0], axis=0)
        items_label = tf.unstack(batched[1], axis=0)
        for item in zip(items_input, items_label):
            result.append(item)
            if len(result) >= count:
                return result
    # Dataset exhausted before reaching count: return what was found instead of
    # falling through to an implicit None, which the caller cannot iterate.
    return result
# Work out which datasets to draw sample predictions for, then process them in order.
prediction_jobs = [("predict_train_$$.png", dataset_train)]
if not PREDICT_AS_ONE:
    # The validate / test datasets only exist when the data was split.
    prediction_jobs.append(("predict_validate_$$.png", dataset_validate))
    if dataset_test is not None:
        prediction_jobs.append(("predict_test_$$.png", dataset_test))

for filename_template, dataset in prediction_jobs:
    plot_predictions(
        os.path.join(DIR_OUTPUT, filename_template),
        get_from_batched(dataset, PREDICT_COUNT),
        colormap,
        model=model,
    )

logger.info(f"Complete at {datetime.now().isoformat()}, elapsed {(datetime.now() - time_start).total_seconds()} seconds")