2022-12-15 19:33:14 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
# @source https://keras.io/examples/vision/deeplabv3_plus/
|
|
|
|
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
from loguru import logger
|
2023-03-09 19:54:27 +00:00
|
|
|
|
2022-12-15 19:33:14 +00:00
|
|
|
from lib.ai.helpers.summarywriter import summarywriter
|
2023-01-09 18:03:23 +00:00
|
|
|
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
import os
|
2023-03-10 17:11:10 +00:00
|
|
|
import io
|
2023-03-09 19:54:27 +00:00
|
|
|
import math
|
2023-03-10 17:31:03 +00:00
|
|
|
import json
|
2022-12-15 19:33:14 +00:00
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
from glob import glob
|
|
|
|
from scipy.io import loadmat
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
2024-08-29 15:43:29 +00:00
|
|
|
import lib.primitives.env
|
2023-06-16 17:23:40 +00:00
|
|
|
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
2023-01-13 17:58:00 +00:00
|
|
|
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
2023-03-03 22:04:21 +00:00
|
|
|
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
2023-03-03 22:44:49 +00:00
|
|
|
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
2023-03-03 20:37:22 +00:00
|
|
|
from lib.ai.components.MetricSpecificity import specificity
|
2023-03-03 22:44:49 +00:00
|
|
|
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
tvt: implement CallbackExtraValidation, which allows for a third split
it should tie into Tensorflow's logging just fine so long as it's the first callback in the queue.
***** TEST SCRIPT *****
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
X = np.random.random((100, 10))
y = np.random.random((100, 1))
split = 80
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(10)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(10)
history = model.fit(train_dataset,
epochs=10,
validation_data=val_dataset,
callbacks=[
CallbackExtraValidation({
"test": val_dataset
}, verbose=0),
tf.keras.callbacks.CSVLogger("/dev/stdout", separator="\t")
],
verbose=0
)
print(f"DEBUG history {history}")
2024-08-30 17:07:17 +00:00
|
|
|
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
2023-01-05 18:26:33 +00:00
|
|
|
|
2023-01-16 17:30:20 +00:00
|
|
|
# Remember when the run began; the completion message at the end of the
# script reports the elapsed time relative to this value.
time_start = datetime.now()

logger.info(f"Starting at {datetime.now().isoformat()}")
|
|
|
|
|
2023-03-01 16:47:36 +00:00
|
|
|
|
|
|
|
# ███████ ███ ██ ██ ██ ██ ██████ ██████ ███ ██ ███ ███ ███████ ███ ██ ████████
|
|
|
|
# ██ ████ ██ ██ ██ ██ ██ ██ ██ ██ ████ ██ ████ ████ ██ ████ ██ ██
|
|
|
|
# █████ ██ ██ ██ ██ ██ ██ ██████ ██ ██ ██ ██ ██ ██ ████ ██ █████ ██ ██ ██ ██
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
|
|
|
|
|
2024-08-29 15:43:29 +00:00
|
|
|
# Bind the settings helper to the short name used below. The file-level
# `import lib.primitives.env` only binds the top-level name `lib`, so without
# this line every `env.*` call in this script raises NameError.
from lib.primitives import env

# Every setting is obtained through env.read(NAME, type[, default]).
# NOTE(review): presumably these come from process environment variables and
# a missing default makes the setting mandatory - confirm in lib.primitives.env.

# --- Data shape ---
IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128)  # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
NUM_CLASSES = 2

# --- Input locations ---
DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)

# --- Dataset pipeline ---
PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
# NOTE(review): the variable read here is the negated name
# NO_REMOVE_ISOLATED_PIXELS with a default of True - confirm how env.read maps
# it onto this positive flag.
REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)

# --- Training hyperparameters ---
EPOCHS = env.read("EPOCHS", int, 50)
LOSS = env.read("LOSS", str, "cross-entropy-dice")  # other possible values: cross-entropy
DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
UPSAMPLE = env.read("UPSAMPLE", int, 2)
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)

# --- Execution tuning (passed straight to model.compile) ---
STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
JIT_COMPILE = env.read("JIT_COMPILE", bool, False)

# --- Output / prediction ---
DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)  # when set, training is skipped and this model is loaded
PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)

# ~~~

# Checkpoints are written to DIR_OUTPUT/checkpoints during training.
env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)

# ~~~

logger.info("DeepLabV3+ rainfall radar TEST")
env.print_all(False)
# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
# 	logger.info(f"> {env_name} {str(globals()[env_name])}")
|
2023-01-13 17:58:00 +00:00
|
|
|
|
2022-12-15 19:33:14 +00:00
|
|
|
|
2023-03-01 16:47:36 +00:00
|
|
|
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██ ██ ███████ ██ ███████ ███████ █████ ██
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
|
2022-12-15 19:33:14 +00:00
|
|
|
|
2023-06-16 17:23:40 +00:00
|
|
|
if not PREDICT_AS_ONE:
    # Normal mode: dataset_mono returns three datasets (train, validate, test).
    # A later check in this script tests `dataset_test is not None`, so the
    # third one may be None.
    dataset_train, dataset_validate, dataset_test = dataset_mono(
        dirpath_input=DIR_RAINFALLWATER,
        batch_size=BATCH_SIZE,
        water_threshold=WATER_THRESHOLD,
        rainfall_scale_up=2,  # done BEFORE cropping to the below size
        output_size=IMAGE_SIZE,
        input_size="same",
        filepath_heightmap=PATH_HEIGHTMAP,
        do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS,
        parallel_reads_multiplier=PARALLEL_READS,
        percentage_validate=SPLIT_VALIDATE,
        percentage_test=SPLIT_TEST,  # was the undefined name SPLIT_TESTs
    )

    # loguru formats messages with str.format(), so the value needs a "{}"
    # placeholder; passing it as a bare extra argument silently drops it.
    logger.info("Train Dataset: {}", dataset_train)
    logger.info("Validation Dataset: {}", dataset_validate)
    logger.info("Test Dataset: {}", dataset_test)
else:
    # Predict-as-one mode: a single dataset, kept under the name dataset_train
    # so the prediction code further down can use it unchanged.
    dataset_train = dataset_mono_predict(
        dirpath_input=DIR_RAINFALLWATER,
        batch_size=BATCH_SIZE,
        water_threshold=WATER_THRESHOLD,
        rainfall_scale_up=2,  # done BEFORE cropping to the below size
        output_size=IMAGE_SIZE,
        input_size="same",
        filepath_heightmap=PATH_HEIGHTMAP,
        do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS,
    )
    # There are no separate splits in this mode; define the names anyway so
    # later references see None instead of raising NameError.
    dataset_validate = None
    dataset_test = None

    logger.info("Dataset AS_ONE: {}", dataset_train)
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
# ███ ███ ██████ ██████ ███████ ██
|
|
|
|
# ████ ████ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██ ████ ██ ██ ██ ██ ██ █████ ██
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██ ██ ██████ ██████ ███████ ███████
|
|
|
|
|
|
|
|
|
2023-01-11 17:20:19 +00:00
|
|
|
if PATH_CHECKPOINT is None:
|
|
|
|
def convolution_block(
|
|
|
|
block_input,
|
|
|
|
num_filters=256,
|
|
|
|
kernel_size=3,
|
|
|
|
dilation_rate=1,
|
|
|
|
padding="same",
|
|
|
|
use_bias=False,
|
|
|
|
):
|
|
|
|
x = tf.keras.layers.Conv2D(
|
|
|
|
num_filters,
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
dilation_rate=dilation_rate,
|
|
|
|
padding="same",
|
|
|
|
use_bias=use_bias,
|
|
|
|
kernel_initializer=tf.keras.initializers.HeNormal(),
|
|
|
|
)(block_input)
|
|
|
|
x = tf.keras.layers.BatchNormalization()(x)
|
|
|
|
return tf.nn.relu(x)
|
|
|
|
|
|
|
|
|
|
|
|
def DilatedSpatialPyramidPooling(dspp_input):
|
|
|
|
dims = dspp_input.shape
|
|
|
|
x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
|
|
|
|
x = convolution_block(x, kernel_size=1, use_bias=True)
|
|
|
|
out_pool = tf.keras.layers.UpSampling2D(
|
|
|
|
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
|
|
|
|
)(x)
|
2023-02-23 16:47:00 +00:00
|
|
|
|
2023-01-11 17:20:19 +00:00
|
|
|
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
|
|
|
|
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
|
|
|
|
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
|
|
|
|
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
|
2023-02-23 16:47:00 +00:00
|
|
|
|
2023-01-11 17:20:19 +00:00
|
|
|
x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
|
|
|
|
output = convolution_block(x, kernel_size=1)
|
|
|
|
return output
|
2023-02-23 16:47:00 +00:00
|
|
|
|
|
|
|
|
2023-05-04 16:40:16 +00:00
|
|
|
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet", upsample=2):
|
2023-01-11 17:20:19 +00:00
|
|
|
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
2023-05-04 16:40:16 +00:00
|
|
|
if upsample > 1:
|
|
|
|
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
|
|
|
|
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
|
|
|
else:
|
|
|
|
logger.info(f"[DeepLabV3+] Upsample disabled")
|
|
|
|
x = model_input
|
2023-03-14 21:51:41 +00:00
|
|
|
|
|
|
|
match backbone:
|
|
|
|
case "resnet":
|
|
|
|
backbone = tf.keras.applications.ResNet50(
|
|
|
|
weights="imagenet" if num_channels == 3 else None,
|
|
|
|
include_top=False, input_tensor=x
|
|
|
|
)
|
|
|
|
case _:
|
|
|
|
raise Exception(f"Error: Unknown backbone {backbone}")
|
|
|
|
|
|
|
|
x = backbone.get_layer("conv4_block6_2_relu").output
|
2023-01-11 17:20:19 +00:00
|
|
|
x = DilatedSpatialPyramidPooling(x)
|
2023-05-04 18:54:51 +00:00
|
|
|
|
2023-05-04 18:57:02 +00:00
|
|
|
factor = 4 if upsample == 2 else 8 # else: upsample == 1. other values are not supported yet because maths
|
2023-01-11 17:20:19 +00:00
|
|
|
input_a = tf.keras.layers.UpSampling2D(
|
2023-05-04 18:57:02 +00:00
|
|
|
size=(image_size // factor // x.shape[1] * 2, image_size // factor // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
|
2023-01-11 17:20:19 +00:00
|
|
|
interpolation="bilinear",
|
|
|
|
)(x)
|
2023-03-14 21:51:41 +00:00
|
|
|
input_b = backbone.get_layer("conv2_block3_2_relu").output
|
2023-01-11 17:20:19 +00:00
|
|
|
input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
|
|
|
|
|
|
|
|
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
|
|
|
|
x = convolution_block(x)
|
|
|
|
x = convolution_block(x)
|
|
|
|
x = tf.keras.layers.UpSampling2D(
|
2023-02-23 16:47:00 +00:00
|
|
|
size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
|
2023-01-11 17:20:19 +00:00
|
|
|
interpolation="bilinear",
|
|
|
|
)(x)
|
|
|
|
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
|
|
|
|
return tf.keras.Model(inputs=model_input, outputs=model_output)
|
|
|
|
|
2023-05-04 16:40:16 +00:00
|
|
|
model = DeeplabV3Plus(
|
|
|
|
image_size=IMAGE_SIZE,
|
|
|
|
num_classes=NUM_CLASSES,
|
|
|
|
upsample=UPSAMPLE,
|
|
|
|
num_channels=8
|
|
|
|
)
|
2023-01-11 17:20:19 +00:00
|
|
|
summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
|
2023-01-11 17:26:57 +00:00
|
|
|
else:
|
2023-03-01 17:19:10 +00:00
|
|
|
model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
|
|
|
|
# Tell Tensorflow about our custom layers so that it can deserialise models that use them
|
2023-03-03 19:34:55 +00:00
|
|
|
"LossCrossEntropyDice": LossCrossEntropyDice,
|
2023-03-09 19:26:57 +00:00
|
|
|
"metric_dice_coefficient": dice_coefficient,
|
2023-03-09 19:43:35 +00:00
|
|
|
"sensitivity": sensitivity,
|
2023-03-09 19:26:57 +00:00
|
|
|
"specificity": specificity,
|
2023-03-09 19:34:45 +00:00
|
|
|
"one_hot_mean_iou": mean_iou
|
2023-03-01 17:19:10 +00:00
|
|
|
})
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
|
|
|
|
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
|
|
|
|
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
|
|
|
|
|
2023-03-22 17:41:34 +00:00
|
|
|
def plot_metric(train, val, name, dir_output):
    """Plot a training metric against its validation counterpart and save it.

    The figure is written to "<dir_output>/<name>.png".

    Args:
        train: Per-epoch values for the training split.
        val: Per-epoch values for the validation split.
        name: Metric name; used for the title, y axis label, series labels
            and the output filename.
        dir_output: Directory the image is written into.
    """
    plt.plot(train, label=f"train_{name}")
    plt.plot(val, label=f"val_{name}")
    plt.title(name)
    plt.xlabel("epoch")
    plt.ylabel(name)
    # The series labels above are only drawn if a legend is requested.
    plt.legend()
    plt.savefig(os.path.join(dir_output, f"{name}.png"))
    plt.close()
|
|
|
|
|
2023-01-11 17:20:19 +00:00
|
|
|
if PATH_CHECKPOINT is None:
|
2023-01-13 17:58:00 +00:00
|
|
|
loss_fn = None
|
|
|
|
if LOSS == "cross-entropy-dice":
|
2023-03-10 20:24:13 +00:00
|
|
|
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
|
2023-01-13 17:58:00 +00:00
|
|
|
elif LOSS == "cross-entropy":
|
2023-01-13 18:47:29 +00:00
|
|
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
2023-01-13 17:58:00 +00:00
|
|
|
else:
|
|
|
|
raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
|
|
|
|
|
2023-01-11 17:20:19 +00:00
|
|
|
model.compile(
|
2023-01-13 18:29:39 +00:00
|
|
|
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
|
2023-01-13 17:58:00 +00:00
|
|
|
loss=loss_fn,
|
2023-03-03 19:34:55 +00:00
|
|
|
metrics=[
|
|
|
|
"accuracy",
|
2023-03-03 22:04:21 +00:00
|
|
|
dice_coefficient,
|
2023-03-03 22:44:49 +00:00
|
|
|
mean_iou(),
|
|
|
|
sensitivity(), # How many true positives were accurately predicted
|
2023-03-03 20:37:22 +00:00
|
|
|
specificity # How many true negatives were accurately predicted?
|
2023-03-03 19:34:55 +00:00
|
|
|
# TODO: Add IoU, F1, Precision, Recall, here.
|
|
|
|
],
|
2023-05-04 17:22:18 +00:00
|
|
|
steps_per_execution=STEPS_PER_EXECUTION,
|
|
|
|
jit_compile=JIT_COMPILE
|
2023-01-11 17:20:19 +00:00
|
|
|
)
|
|
|
|
logger.info(">>> Beginning training")
|
|
|
|
history = model.fit(dataset_train,
|
|
|
|
validation_data=dataset_validate,
|
2024-08-29 18:33:40 +00:00
|
|
|
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
|
2023-01-12 18:54:39 +00:00
|
|
|
epochs=EPOCHS,
|
2023-01-11 17:20:19 +00:00
|
|
|
callbacks=[
|
tvt: implement CallbackExtraValidation, which allows for a third split
it should tie into Tensorflow's logging just fine so long as it's the first callback in the queue.
***** TEST SCRIPT *****
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
X = np.random.random((100, 10))
y = np.random.random((100, 1))
split = 80
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(10)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(10)
history = model.fit(train_dataset,
epochs=10,
validation_data=val_dataset,
callbacks=[
CallbackExtraValidation({
"test": val_dataset
}, verbose=0),
tf.keras.callbacks.CSVLogger("/dev/stdout", separator="\t")
],
verbose=0
)
print(f"DEBUG history {history}")
2024-08-30 17:07:17 +00:00
|
|
|
CallbackExtraValidation(model, {
|
|
|
|
"test": dataset_test # Can be None because it handles that
|
|
|
|
}),
|
2023-01-11 17:20:19 +00:00
|
|
|
tf.keras.callbacks.CSVLogger(
|
|
|
|
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
|
|
|
|
separator="\t"
|
2023-01-09 18:03:23 +00:00
|
|
|
),
|
2023-01-11 17:20:19 +00:00
|
|
|
CallbackCustomModelCheckpoint(
|
|
|
|
model_to_checkpoint=model,
|
|
|
|
filepath=os.path.join(
|
|
|
|
DIR_OUTPUT,
|
2023-01-11 17:28:13 +00:00
|
|
|
"checkpoints",
|
2023-01-11 17:20:19 +00:00
|
|
|
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
|
|
|
|
),
|
|
|
|
monitor="loss"
|
|
|
|
),
|
|
|
|
],
|
|
|
|
steps_per_epoch=STEPS_PER_EPOCH,
|
|
|
|
)
|
|
|
|
logger.info(">>> Training complete")
|
|
|
|
logger.info(">>> Plotting graphs")
|
2023-03-22 17:41:34 +00:00
|
|
|
|
2023-05-07 18:00:02 +00:00
|
|
|
plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
|
2023-05-19 21:00:23 +00:00
|
|
|
plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
|
2023-05-07 18:00:02 +00:00
|
|
|
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
|
|
|
|
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
|
|
|
|
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
|
|
|
|
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)
|
2023-03-22 17:41:34 +00:00
|
|
|
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
|
|
|
|
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██
|
|
|
|
# ██ ██ ██ ██ █████ █████ ██████ █████ ██ ██ ██ ██ █████
|
|
|
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
|
|
# ██ ██ ████ ██ ███████ ██ ██ ███████ ██ ████ ██████ ███████
|
|
|
|
|
|
|
|
# Read the "colormap" entry from the .mat file at PATH_COLOURMAP, multiply it
# by 100 and store it as unsigned 8-bit values for use when colouring masks.
colormap = (loadmat(PATH_COLOURMAP)["colormap"] * 100).astype(np.uint8)
|
|
|
|
|
|
|
|
|
2023-03-09 18:54:28 +00:00
|
|
|
def infer(model, image_tensor, do_argmax=True):
    """Run the model on a single sample and return its squeezed prediction.

    Args:
        model: Object whose predict() method is called.
        image_tensor: One input sample; a leading batch axis is added to it.
        do_argmax: Accepted but not used by this function.

    Returns:
        The prediction with all size-1 axes removed.
    """
    single_item_batch = tf.expand_dims(image_tensor, axis=0)
    return tf.squeeze(model.predict(single_item_batch))
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def decode_segmentation_masks(mask, colormap, n_classes):
    """Turn a mask of class ids into a colour image.

    Args:
        mask: Array of class ids.
        colormap: Lookup table indexed as colormap[class_id, channel].
        n_classes: Class ids 0 .. n_classes - 1 are coloured; positions
            holding any other value stay 0 in every channel.

    Returns:
        A uint8 array with the three colour planes stacked on axis 2.
    """
    # One zero-filled uint8 plane per colour channel.
    planes = [np.zeros_like(mask).astype(np.uint8) for _ in range(3)]

    for class_id in range(n_classes):
        selected = mask == class_id
        for channel, plane in enumerate(planes):
            plane[selected] = colormap[class_id, channel]

    return np.stack(planes, axis=2)
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
|
2023-01-05 17:09:09 +00:00
|
|
|
def get_overlay(image, coloured_mask):
    """Blend a coloured mask over an image.

    Args:
        image: Image array; converted through array_to_img and back to a
            uint8 numpy array before blending.
        coloured_mask: Mask image to blend on top.

    Returns:
        The weighted sum 0.35 * image + 0.65 * coloured_mask from cv2.addWeighted.
    """
    as_image = tf.keras.preprocessing.image.array_to_img(image)
    base = np.array(as_image).astype(np.uint8)
    return cv2.addWeighted(base, 0.35, coloured_mask, 0.65, 0)
|
2022-12-15 19:33:14 +00:00
|
|
|
|
|
|
|
|
2023-01-12 18:43:48 +00:00
|
|
|
def plot_samples_matplotlib(filepath, display_list):
    """Draw every item of display_list into one figure and save it.

    Items are laid out on a grid with 2 rows. Items whose last axis has
    length 3 are drawn through array_to_img; anything else is drawn directly
    and gets a colour bar.

    Args:
        filepath: Where the figure is saved (dpi 200).
        display_list: Images to draw.
    """
    plt.figure(figsize=(16, 8))
    try:
        columns = math.ceil(len(display_list) / 2)
        for index, item in enumerate(display_list):
            plt.subplot(2, columns, index + 1)
            if item.shape[-1] == 3:
                plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
            else:
                plt.imshow(item)
                # NOTE(review): SOURCE has lost its indentation; the colour bar
                # is attached to this branch because it is only meaningful for
                # single-channel data - confirm against the original file.
                plt.colorbar()
        plt.savefig(filepath, dpi=200)
    finally:
        # This function is called once per sample, so without closing, the
        # figures pile up in memory for the whole prediction run.
        plt.close()
|
2022-12-15 19:33:14 +00:00
|
|
|
|
2023-03-10 17:07:44 +00:00
|
|
|
def save_samples(filepath, save_list):
    """Append save_list to filepath as one line of JSON.

    Args:
        filepath: File to append to; it is created if it does not exist.
        save_list: Any JSON-serialisable value.

    Raises:
        TypeError: If save_list cannot be serialised to JSON. In that case
            nothing is written and the file is not created.
    """
    # Serialise before opening the file so that a failure cannot leave a
    # half-written line behind.
    line = json.dumps(save_list)
    # The context manager closes the handle even if the write fails.
    with open(filepath, "a") as handle:
        handle.write(line + "\n")
|
2022-12-15 19:33:14 +00:00
|
|
|
|
2022-12-16 19:52:59 +00:00
|
|
|
def plot_predictions(filepath, input_items, colormap, model):
    """Predict every item, saving one figure per item plus the raw outputs.

    Args:
        filepath: Output image path template; "$$" is replaced by the item's
            index. The raw predictions go to the same path with "_$$" removed
            and ".png" replaced by ".jsonl".
        input_items: Iterable of pairs; element 0 is the model input and
            element 1 the label.
        colormap: Lookup table handed to decode_segmentation_masks.
        model: Model used for prediction.
    """
    filepath_jsonl = filepath.replace("_$$", "").replace(".png", ".jsonl")
    # save_samples appends, so empty any file left over from an earlier run.
    if os.path.exists(filepath_jsonl):
        os.truncate(filepath_jsonl, 0)

    for index, input_pair in enumerate(input_items):
        features = input_pair[0]
        label = input_pair[1]

        prediction_mask = infer(image_tensor=features, model=model)
        prediction_classes = tf.argmax(prediction_mask, axis=2)
        prediction_colormap = decode_segmentation_masks(prediction_classes, colormap, 2)

        panels = [
            tf.math.reduce_max(features[:, :, :-1], axis=-1),  # rainfall only
            features[:, :, -1],  # heightmap
            label,
            prediction_mask[:, :, 1],
            prediction_colormap,
        ]
        plot_samples_matplotlib(filepath.replace("$$", str(index)), panels)

        save_samples(filepath_jsonl, prediction_mask.numpy().tolist())
|
2022-12-15 19:33:14 +00:00
|
|
|
|
2023-01-12 16:13:04 +00:00
|
|
|
def get_from_batched(dataset, count):
    """Collect up to count (input, label) pairs from a batched dataset.

    Args:
        dataset: Iterable of batches; element 0 of a batch holds the inputs
            and element 1 the labels, both stacked along axis 0.
        count: Maximum number of pairs to return.

    Returns:
        A list of (input, label) tuples. It holds fewer than count items when
        the dataset runs out first, and is empty when count <= 0.
    """
    result = []
    if count <= 0:
        # Nothing requested; previously this still returned one item.
        return result

    for batched in dataset:
        items_input = tf.unstack(batched[0], axis=0)
        items_label = tf.unstack(batched[1], axis=0)
        for item in zip(items_input, items_label):
            result.append(item)
            if len(result) >= count:
                return result

    # Dataset exhausted before reaching count: return what was gathered
    # rather than falling off the end and returning None, which the caller
    # would then fail to iterate.
    return result
|
|
|
|
|
|
|
|
|
|
|
|
# Sample predictions from the training data (or from the single combined
# dataset when PREDICT_AS_ONE is set, which is stored under the same name).
plot_predictions(
    os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
    get_from_batched(dataset_train, PREDICT_COUNT),
    colormap,
    model=model,
)

if not PREDICT_AS_ONE:
    plot_predictions(
        os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
        get_from_batched(dataset_validate, PREDICT_COUNT),
        colormap,
        model=model,
    )
    # The test split is only created outside PREDICT_AS_ONE mode, so it is
    # checked here, where the name is guaranteed to exist.
    if dataset_test is not None:
        plot_predictions(
            os.path.join(DIR_OUTPUT, "predict_test_$$.png"),
            get_from_batched(dataset_test, PREDICT_COUNT),
            colormap,
            model=model,
        )

# Report the finish time and the seconds elapsed since time_start.
logger.info(f"Complete at {str(datetime.now().isoformat())}, elapsed {str((datetime.now() - time_start).total_seconds())} seconds")
|