research-rainfallradar/aimodel/src/deeplabv3_plus_test_rainfall.py

305 lines
10 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# @source https://keras.io/examples/vision/deeplabv3_plus/
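# Dataset of the original Keras example (not used by this script, which reads rainfall radar data from DIR_RAINFALLWATER): https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]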
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
2023-01-09 18:03:23 +00:00
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import cv2
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
import tensorflow as tf
2023-01-05 18:26:33 +00:00
from lib.dataset.dataset_mono import dataset_mono
# Run configuration, read from environment variables. Variables with a
# fallback are optional; DIR_RAINFALLWATER, PATH_HEIGHTMAP and PATH_COLOURMAP
# are required (a missing one raises KeyError at startup).
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", 128)) # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 64))
NUM_CLASSES = 2  # number of segmentation classes (model output channels)

DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
PATH_COLOURMAP = os.environ["PATH_COLOURMAP"]

# None lets Keras run each epoch until the training dataset is exhausted.
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
EPOCHS = int(os.environ.get("EPOCHS", 25))
PREDICT_COUNT = int(os.environ.get("PREDICT_COUNT", 4))

DIR_OUTPUT = os.environ.get(
    "DIR_OUTPUT",
    f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST",
)
# When set, a saved model is loaded instead of building and training a new one.
PATH_CHECKPOINT = os.environ.get("PATH_CHECKPOINT")

# exist_ok=True also creates checkpoints/ when DIR_OUTPUT already exists
# (e.g. a reused output directory), instead of only on a fresh DIR_OUTPUT.
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"), exist_ok=True)
# Echo the effective configuration so every run's log records its settings.
logger.info("DeepLabV3+ rainfall radar TEST")
logger.info(f"> IMAGE_SIZE {IMAGE_SIZE}")
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
logger.info(f"> DIR_RAINFALLWATER {DIR_RAINFALLWATER}")
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
logger.info(f"> EPOCHS {EPOCHS}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
logger.info(f"> PATH_CHECKPOINT {PATH_CHECKPOINT}")
logger.info(f"> PREDICT_COUNT {PREDICT_COUNT}")

# Training / validation splits from the rainfall-radar + water-depth data.
dataset_train, dataset_validate = dataset_mono(
    dirpath_input=DIR_RAINFALLWATER,
    batch_size=BATCH_SIZE,
    water_threshold=0.1,
    rainfall_scale_up=2, # done BEFORE cropping to the below size
    output_size=IMAGE_SIZE,
    input_size="same",
    filepath_heightmap=PATH_HEIGHTMAP,
)

# Interpolate with f-strings: loguru applies str.format(*args) to the message,
# so an extra positional argument without a "{}" placeholder is silently dropped.
logger.info(f"Train Dataset: {dataset_train}")
logger.info(f"Validation Dataset: {dataset_validate}")
# ███ ███ ██████ ██████ ███████ ██
# ████ ████ ██ ██ ██ ██ ██ ██
# ██ ████ ██ ██ ██ ██ ██ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████
# Build a new DeepLabV3+ model (layer helpers follow the Keras example cited at
# the top of this file), or load a previously saved model when PATH_CHECKPOINT
# is set.
if PATH_CHECKPOINT is None:
    def convolution_block(
        block_input,
        num_filters=256,
        kernel_size=3,
        dilation_rate=1,
        padding="same",
        use_bias=False,
    ):
        """Apply Conv2D -> BatchNormalization -> ReLU.

        Args:
            block_input: Keras tensor to convolve.
            num_filters: Number of output channels.
            kernel_size: Convolution kernel size.
            dilation_rate: Dilation (atrous) rate of the convolution.
            padding: Conv2D padding mode; the default "same" preserves the
                spatial size.
            use_bias: Whether the convolution adds a bias term.

        Returns:
            The ReLU-activated output tensor.
        """
        x = tf.keras.layers.Conv2D(
            num_filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding=padding,
            use_bias=use_bias,
            kernel_initializer=tf.keras.initializers.HeNormal(),
        )(block_input)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.nn.relu(x)

    def DilatedSpatialPyramidPooling(dspp_input):
        """Dilated spatial pyramid pooling head.

        Five parallel branches over `dspp_input` are concatenated on the
        channel axis and fused with a 1x1 convolution:
          * a global average-pool branch, upsampled back to the input size,
          * a 1x1 convolution,
          * 3x3 convolutions with dilation rates 6, 12 and 18.
        """
        dims = dspp_input.shape
        # Pool over the full spatial extent (dims[-3], dims[-2]), i.e. height
        # and width for a channels-last tensor.
        x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
        x = convolution_block(x, kernel_size=1, use_bias=True)
        out_pool = tf.keras.layers.UpSampling2D(
            size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
        )(x)
        out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
        out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
        out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
        out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
        x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
        output = convolution_block(x, kernel_size=1)
        return output

    def DeeplabV3Plus(image_size, num_classes, num_channels=3):
        """Build a DeepLabV3+ segmentation model on a ResNet50 backbone.

        Args:
            image_size: Height and width of the square input.
            num_classes: Number of output channels (one logit per class).
            num_channels: Number of input channels. ImageNet weights are only
                loaded when this is 3 (they were trained on 3-channel RGB);
                otherwise the backbone is randomly initialised.

        Returns:
            A tf.keras.Model producing per-pixel class logits at the input
            resolution (the final Conv2D has no activation).
        """
        model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
        resnet50 = tf.keras.applications.ResNet50(
            weights="imagenet" if num_channels == 3 else None,
            include_top=False, input_tensor=model_input
        )
        # Encoder: deep features from the conv4 stage, refined by the pyramid
        # pooling head, then upsampled to 1/4 of the input resolution.
        x = resnet50.get_layer("conv4_block6_2_relu").output
        x = DilatedSpatialPyramidPooling(x)
        input_a = tf.keras.layers.UpSampling2D(
            size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
            interpolation="bilinear",
        )(x)
        # Decoder: fuse with low-level conv2-stage features reduced to 48 channels.
        input_b = resnet50.get_layer("conv2_block3_2_relu").output
        input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
        x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
        x = convolution_block(x)
        x = convolution_block(x)
        x = tf.keras.layers.UpSampling2D(
            size=(image_size // x.shape[1], image_size // x.shape[2]),
            interpolation="bilinear",
        )(x)
        model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
        return tf.keras.Model(inputs=model_input, outputs=model_output)

    # NOTE(review): 8 input channels is assumed to match what dataset_mono
    # yields per sample — confirm against lib.dataset.dataset_mono.
    model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
    summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
else:
    model = tf.keras.models.load_model(PATH_CHECKPOINT)
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
def _plot_history_metric(history, key, title, filename):
    """Plot one per-epoch series from a Keras History and save it as an image.

    Args:
        history: The object returned by ``model.fit``; ``history.history[key]``
            must hold one value per epoch.
        key: Metric name to plot; also used as the y-axis label.
        title: Plot title.
        filename: Output file name, written inside DIR_OUTPUT.
    """
    plt.plot(history.history[key])
    plt.title(title)
    plt.ylabel(key)
    plt.xlabel("epoch")
    plt.savefig(os.path.join(DIR_OUTPUT, filename))
    plt.close()  # start each metric on a fresh figure


# Train only when building a fresh model; a model loaded from PATH_CHECKPOINT
# is used for inference as-is.
if PATH_CHECKPOINT is None:
    # The model's final Conv2D has no activation, so the loss takes raw logits
    # together with integer class labels.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=loss,
        metrics=["accuracy"],
    )
    logger.info(">>> Beginning training")
    history = model.fit(dataset_train,
        validation_data=dataset_validate,
        epochs=EPOCHS,
        callbacks=[
            # Per-epoch metrics as tab-separated values.
            tf.keras.callbacks.CSVLogger(
                filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
                separator="\t"
            ),
            CallbackCustomModelCheckpoint(
                model_to_checkpoint=model,
                filepath=os.path.join(
                    DIR_OUTPUT,
                    "checkpoints",
                    "checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
                ),
                monitor="loss"
            ),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
    )
    logger.info(">>> Training complete")
    logger.info(">>> Plotting graphs")
    # (history key, plot title, output file) for every curve we save.
    for metric_key, title, filename in (
        ("loss", "Training Loss", "loss.png"),
        ("accuracy", "Training Accuracy", "acc.png"),
        ("val_loss", "Validation Loss", "val_loss.png"),
        ("val_accuracy", "Validation Accuracy", "val_acc.png"),
    ):
        _plot_history_metric(history, metric_key, title, filename)
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██
# ██ ██ ██ ██ █████ █████ ██████ █████ ██ ██ ██ ██ █████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ████ ██ ███████ ██ ██ ███████ ██ ████ ██████ ███████
# Load the per-class colour table (the "colormap" variable of the MATLAB file
# at PATH_COLOURMAP), scale it by 100 and cast it to uint8 for display.
# NOTE(review): if the stored values are fractions in [0, 1], scaling by 255
# would use the full 8-bit range; the 100 is presumably inherited from the
# Keras example — confirm.
colormap = (loadmat(PATH_COLOURMAP)["colormap"] * 100).astype(np.uint8)
def infer(model, image_tensor):
    """Predict a per-pixel class-index map for a single (unbatched) sample.

    Args:
        model: Keras segmentation model whose output carries the class logits
            on its last axis.
        image_tensor: One input sample without a batch dimension.

    Returns:
        A tensor of class indices (argmax over the class axis) with the
        sample's spatial shape.
    """
    # Add a batch axis of size 1, predict, then drop exactly that batch axis.
    # A bare tf.squeeze would also remove any other size-1 axis (e.g. a
    # single-class output), moving the class axis away from where argmax looks.
    predictions = model.predict(tf.expand_dims(image_tensor, axis=0))
    predictions = tf.squeeze(predictions, axis=0)
    return tf.argmax(predictions, axis=-1)
def decode_segmentation_masks(mask, colormap, n_classes):
    """Convert a class-index mask into an RGB image using a colour table.

    Args:
        mask: Array-like of integer class indices, expected 2-D (the colour
            channels are stacked on axis 2).
        colormap: Array indexed as colormap[class, channel] giving the R, G
            and B values for each class.
        n_classes: Number of classes to colour; pixels whose value is not in
            range(n_classes) stay (0, 0, 0).

    Returns:
        A uint8 array; for an (H, W) mask its shape is (H, W, 3).
    """
    planes = [np.zeros_like(mask, dtype=np.uint8) for _channel in range(3)]
    for class_index in range(n_classes):
        hits = mask == class_index
        for channel, plane in enumerate(planes):
            plane[hits] = colormap[class_index, channel]
    return np.stack(planes, axis=2)
def get_overlay(image, coloured_mask):
    """Blend an image with a colour-coded mask.

    The image is round-tripped through Keras' array_to_img to obtain an 8-bit
    array, then mixed as 35% image + 65% mask with cv2.addWeighted.

    Args:
        image: Array-like image accepted by array_to_img.
        coloured_mask: Mask to blend in; cv2.addWeighted requires it to match
            the converted image's size and dtype (uint8).

    Returns:
        The blended image.
    """
    pil_image = tf.keras.preprocessing.image.array_to_img(image)
    image_u8 = np.array(pil_image).astype(np.uint8)
    return cv2.addWeighted(image_u8, 0.35, coloured_mask, 0.65, 0)
def plot_samples_matplotlib(filepath, display_list):
    """Save the given images side by side, in a single row, to `filepath`.

    Items whose last axis has size 3 are converted with Keras' array_to_img
    before display; anything else is passed to plt.imshow unchanged.

    Args:
        filepath: Destination image path (saved at 200 dpi).
        display_list: Sequence of array-likes to show, left to right.
    """
    # Draw on a dedicated figure and always close it. Otherwise every call
    # draws onto pyplot's implicit current figure, which is never released,
    # so figures leak and accumulate across calls.
    figure = plt.figure()
    try:
        for position, item in enumerate(display_list, start=1):
            plt.subplot(1, len(display_list), position)
            if item.shape[-1] == 3:
                plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
            else:
                plt.imshow(item)
        plt.savefig(filepath, dpi=200)
    finally:
        plt.close(figure)
def plot_predictions(filepath, input_items, colormap, model, n_classes=None):
    """Run the model on each (input, label) pair and save label-vs-prediction plots.

    Args:
        filepath: Output path template; every "$$" is replaced with the item's
            0-based index, giving one image file per item.
        input_items: Iterable of (input, label) pairs, e.g. from get_from_batched.
        colormap: Colour table passed to decode_segmentation_masks.
        model: Model passed to infer.
        n_classes: Number of classes to colour in the prediction; defaults to
            NUM_CLASSES.
    """
    if n_classes is None:
        n_classes = NUM_CLASSES
    for index, input_pair in enumerate(input_items):
        input_tensor, label = input_pair[0], input_pair[1]
        prediction_mask = infer(image_tensor=input_tensor, model=model)
        prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, n_classes)
        plot_samples_matplotlib(
            filepath.replace("$$", str(index)),
            [
                # The label is shown without colour decoding, next to the
                # colour-decoded prediction.
                label,
                prediction_colormap,
            ],
        )
def get_from_batched(dataset, count):
    """Collect the first `count` individual (input, label) pairs from a batched dataset.

    Args:
        dataset: Iterable of batches where batch[0] holds the inputs and
            batch[1] the labels, each with the batch on its leading axis.
        count: Maximum number of pairs to return.

    Returns:
        A list of at most `count` (input, label) tuples: fewer when the dataset
        runs out first, and an empty list when count <= 0.
    """
    result = []
    if count <= 0:
        return result
    for batched in dataset:
        # Iterating a batch walks its leading (batch) axis, one sample at a time.
        for item in zip(batched[0], batched[1]):
            result.append(item)
            if len(result) >= count:
                return result
    # Dataset exhausted before `count` items: return what was collected rather
    # than falling off the end with an implicit None.
    return result
# Save PREDICT_COUNT label/prediction comparison plots for each data split,
# training first, then validation. "$$" in each filename becomes the sample
# index inside plot_predictions.
for filename_template, split_dataset in (
    ("predict_train_$$.png", dataset_train),
    ("predict_validate_$$.png", dataset_validate),
):
    plot_predictions(
        os.path.join(DIR_OUTPUT, filename_template),
        get_from_batched(split_dataset, PREDICT_COUNT),
        colormap,
        model=model,
    )