#!/usr/bin/env python3
# @source https://keras.io/examples/vision/deeplabv3_plus/
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import io
import math
import json

cv2 = None # OpenCV is optional; it's imported lazily in get_overlay() below if needed
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
import tensorflow as tf
from lib.primitives import env
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation

time_start = datetime.now()
logger.info(f"Starting at {datetime.now().isoformat()}")

# ███████ ███ ██ ██ ██ ██ ██████ ██████ ███ ██ ███ ███ ███████ ███ ██ ████████
# ██ ████ ██ ██ ██ ██ ██ ██ ██ ██ ████ ██ ████ ████ ██ ████ ██ ██
# █████ ██ ██ ██ ██ ██ ██ ██████ ██ ██ ██ ██ ██ ██ ████ ██ █████ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██

IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
NUM_CLASSES = 2 # binary segmentation: water / no water
DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)
PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)
EPOCHS = env.read("EPOCHS", int, 50)
LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy
DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
UPSAMPLE = env.read("UPSAMPLE", int, 2)
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)
# NOTE: RANDSEED is declared and handled in src/lib/dataset/primitives/shuffle.py
STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
JIT_COMPILE = env.read("JIT_COMPILE", bool, False)
DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)
PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)
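
# Example invocation (hypothetical paths, illustrative values only; see the
# env.read() calls above for the full list of recognised settings):
#   DIR_RAINFALLWATER=/data/rainfallwater \
#   PATH_HEIGHTMAP=/data/heightmap \
#   PATH_COLOURMAP=/data/colourmap.mat \
#   IMAGE_SIZE=128 BATCH_SIZE=64 EPOCHS=50 LOSS=cross-entropy-dice \
#   python3 path/to/this/script.py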

# ~~~

env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)

# ~~~

logger.info("DeepLabV3+ rainfall radar TEST")
env.print_all(False)
# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
# logger.info(f"> {env_name} {str(globals()[env_name])}")
# ██████ █████ ████████ █████ ███████ ███████ ████████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ███████ ██ ███████ ███████ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██

if not PREDICT_AS_ONE:
	dataset_train, dataset_validate, dataset_test = dataset_mono(
		dirpath_input=DIR_RAINFALLWATER,
		batch_size=BATCH_SIZE,
		water_threshold=WATER_THRESHOLD,
		rainfall_scale_up=2, # done BEFORE cropping to the below size
		output_size=IMAGE_SIZE,
		input_size="same",
		filepath_heightmap=PATH_HEIGHTMAP,
		do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS,
		parallel_reads_multiplier=PARALLEL_READS,
		percentage_validate=SPLIT_VALIDATE,
		percentage_test=SPLIT_TEST
	)
	
	logger.info("Train Dataset: {}", dataset_train)
	logger.info("Validation Dataset: {}", dataset_validate)
	logger.info("Test Dataset: {}", dataset_test)
else:
	# PREDICT_AS_ONE: load the entire dataset as a single split
	dataset_train = dataset_mono_predict(
		dirpath_input=DIR_RAINFALLWATER,
		batch_size=BATCH_SIZE,
		water_threshold=WATER_THRESHOLD,
		rainfall_scale_up=2, # done BEFORE cropping to the below size
		output_size=IMAGE_SIZE,
		input_size="same",
		filepath_heightmap=PATH_HEIGHTMAP,
		do_remove_isolated_pixels=REMOVE_ISOLATED_PIXELS
	)
	logger.info("Dataset AS_ONE: {}", dataset_train)

# ███ ███ ██████ ██████ ███████ ██
# ████ ████ ██ ██ ██ ██ ██ ██
# ██ ████ ██ ██ ██ ██ ██ █████ ██
# ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████

if PATH_CHECKPOINT is None:
	def convolution_block(
		block_input,
		num_filters=256,
		kernel_size=3,
		dilation_rate=1,
		padding="same",
		use_bias=False,
	):
		"""Conv2D with He-normal initialisation, followed by BatchNormalization and ReLU."""
		x = tf.keras.layers.Conv2D(
			num_filters,
			kernel_size=kernel_size,
			dilation_rate=dilation_rate,
			padding=padding,
			use_bias=use_bias,
			kernel_initializer=tf.keras.initializers.HeNormal(),
		)(block_input)
		x = tf.keras.layers.BatchNormalization()(x)
		return tf.nn.relu(x)
	
	def DilatedSpatialPyramidPooling(dspp_input):
		dims = dspp_input.shape
		x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
		x = convolution_block(x, kernel_size=1, use_bias=True)
		out_pool = tf.keras.layers.UpSampling2D(
			size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
		)(x)
		
		out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
		out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
		out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
		out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
		
		x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
		output = convolution_block(x, kernel_size=1)
		return output
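	
	# Worked example of the pyramid's shape arithmetic, assuming a dspp_input
	# of shape (batch, 16, 16, channels): AveragePooling2D collapses it to
	# 1×1, the 1×1 convolution_block maps it to 256 filters, and UpSampling2D
	# scales it back up by (16//1, 16//1) to 16×16. Concatenating with the
	# four dilated branches (rates 1, 6, 12, 18; 256 filters each) gives
	# 5×256 channels, which the final 1×1 convolution_block reduces to 256.
	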
	def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet", upsample=2):
		model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
		if upsample > 1:
			logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
			x = tf.keras.layers.UpSampling2D(size=2)(model_input) # NOTE: always 2x; only upsample values 1 and 2 are supported (see the factor calculation below)
		else:
			logger.info("[DeepLabV3+] Upsample disabled")
			x = model_input
		
		match backbone: # NOTE: match/case requires Python 3.10+
			case "resnet":
				backbone = tf.keras.applications.ResNet50(
					weights="imagenet" if num_channels == 3 else None,
					include_top=False, input_tensor=x
				)
			case _:
				raise Exception(f"Error: Unknown backbone {backbone}")
		
		x = backbone.get_layer("conv4_block6_2_relu").output
		x = DilatedSpatialPyramidPooling(x)
		
		factor = 4 if upsample == 2 else 8 # else: upsample == 1. other values are not supported yet because maths
		
		input_a = tf.keras.layers.UpSampling2D(
			size=(image_size // factor // x.shape[1] * 2, image_size // factor // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
			interpolation="bilinear",
		)(x)
		
		input_b = backbone.get_layer("conv2_block3_2_relu").output
		input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
		
		x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
		x = convolution_block(x)
		x = convolution_block(x)
		x = tf.keras.layers.UpSampling2D(
			size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
			interpolation="bilinear",
		)(x)
		model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x) # no activation, so the output is raw logits
		return tf.keras.Model(inputs=model_input, outputs=model_output)
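	
	# Shape walkthrough under the defaults (IMAGE_SIZE=128, UPSAMPLE=2): the
	# input is upsampled to 256×256, so "conv4_block6_2_relu" (output stride
	# 16) yields 16×16 features. With factor=4, input_a is upsampled by
	# 128//4//16*2 = 4 to 64×64, matching "conv2_block3_2_relu" (output
	# stride 4, so 64×64). The final UpSampling2D then scales by 128//64 = 2,
	# giving a 128×128 output.
	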
	model = DeeplabV3Plus(
		image_size=IMAGE_SIZE,
		num_classes=NUM_CLASSES,
		upsample=UPSAMPLE,
		num_channels=8 # 7 rainfall radar channels, plus the heightmap as the last channel
	)
	summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
else:
	model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
		# Tell Tensorflow about our custom layers so that it can deserialise models that use them
		"LossCrossEntropyDice": LossCrossEntropyDice,
		"metric_dice_coefficient": dice_coefficient,
		"sensitivity": sensitivity,
		"specificity": specificity,
		"one_hot_mean_iou": mean_iou
	})

# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████

def plot_metric(train, val, name, dir_output):
	plt.plot(train, label=f"train_{name}")
	plt.plot(val, label=f"val_{name}")
	plt.title(name)
	plt.xlabel("epoch")
	plt.ylabel(name)
	plt.legend() # without this, the labels above are never drawn
	plt.savefig(os.path.join(dir_output, f"{name}.png"))
	plt.close()

if PATH_CHECKPOINT is None:
	loss_fn = None
	if LOSS == "cross-entropy-dice":
		loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
	elif LOSS == "cross-entropy":
		loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
	else:
		raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
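	
	# The model's final Conv2D has no activation, so it emits raw logits;
	# hence from_logits=True for the cross-entropy above. LossCrossEntropyDice
	# is assumed to expect logits in the same way.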
	model.compile(
		optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
		loss=loss_fn,
		metrics=[
			"accuracy",
			dice_coefficient,
			mean_iou(),
			sensitivity(), # How many true positives were accurately predicted?
			specificity # How many true negatives were accurately predicted?
			# TODO: Add F1 and precision here.
		],
		steps_per_execution=STEPS_PER_EXECUTION,
		jit_compile=JIT_COMPILE
	)
	logger.info(">>> Beginning training")
	history = model.fit(dataset_train,
		validation_data=dataset_validate,
		# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
		epochs=EPOCHS,
		callbacks=[
			# CallbackExtraValidation adds the extra (test) split. It must be
			# the first callback in the queue so that the metrics it computes
			# tie into Tensorflow's logging and are visible to the callbacks
			# that follow (e.g. CSVLogger).
			CallbackExtraValidation(model, {
				"test": dataset_test # Can be None because it handles that
			}),
			tf.keras.callbacks.CSVLogger(
				filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
				separator="\t"
			),
			CallbackCustomModelCheckpoint(
				model_to_checkpoint=model,
				filepath=os.path.join(
					DIR_OUTPUT,
					"checkpoints",
					"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
				),
				monitor="loss"
			),
		],
		steps_per_epoch=STEPS_PER_EPOCH,
	)
	logger.info(">>> Training complete")
	logger.info(">>> Plotting graphs")
	
	plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
	plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
	plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
	plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
	plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
	plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)

# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██
# ██ ██ ██ ██ █████ █████ ██████ █████ ██ ██ ██ ██ █████
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ████ ██ ███████ ██ ██ ███████ ██ ████ ██████ ███████
# Loading the Colormap
colormap = loadmat(
	PATH_COLOURMAP
)["colormap"]
colormap = colormap * 100
colormap = colormap.astype(np.uint8)
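
# Note: assuming the .mat file stores colour values as floats in [0, 1],
# multiplying by 100 (rather than 255) yields muted RGB values (at most 100
# of 255) before the uint8 cast.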
def infer(model, image_tensor):
	"""Run the model on a single image tensor, returning per-class logits of shape [height, width, num_classes]."""
	predictions = model.predict(tf.expand_dims(image_tensor, axis=0))
	predictions = tf.squeeze(predictions)
	return predictions

def decode_segmentation_masks(mask, colormap, n_classes):
	r = np.zeros_like(mask).astype(np.uint8)
	g = np.zeros_like(mask).astype(np.uint8)
	b = np.zeros_like(mask).astype(np.uint8)
	for l in range(0, n_classes):
		idx = mask == l
		r[idx] = colormap[l, 0]
		g[idx] = colormap[l, 1]
		b[idx] = colormap[l, 2]
	rgb = np.stack([r, g, b], axis=2)
	return rgb
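
# Note: assuming colormap is an (n_classes, 3) uint8 array, the loop above is
# equivalent to a single NumPy fancy-indexing lookup:
#   rgb = colormap[:n_classes][mask] # (H, W) integer mask of class ids gives (H, W, 3) RGB
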
def get_overlay(image, coloured_mask):
	# Import cv2 only when used, since it might not be available and this
	# function isn't currently used (was prob something I dropped halfway
	# through writing 'cause I got distracted)
	global cv2
	if cv2 is None:
		import cv2 # the global statement above makes this binding module-level
	image = tf.keras.preprocessing.image.array_to_img(image)
	image = np.array(image).astype(np.uint8)
	overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0)
	return overlay
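
# cv2.addWeighted(src1, alpha, src2, beta, gamma) computes
# src1*alpha + src2*beta + gamma pixel-wise, so the overlay above is a
# 35% image / 65% mask blend.
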
def plot_samples_matplotlib(filepath, display_list):
	plt.figure(figsize=(16, 8))
	for i in range(len(display_list)):
		plt.subplot(2, math.ceil(len(display_list) / 2), i + 1)
		if display_list[i].shape[-1] == 3:
			plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
		else:
			plt.imshow(display_list[i])
			plt.colorbar()
	plt.savefig(filepath, dpi=200)
	plt.close() # free the figure, since this gets called once per prediction
def save_samples(filepath, save_list):
	with io.open(filepath, "a") as handle:
		json.dump(save_list, handle)
		handle.write("\n")
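
# Each save_samples() call appends one JSON document followed by a newline, so
# the output file is in JSON Lines format. Read it back with e.g.:
#   with open(filepath_jsonl) as handle:
#       masks = [json.loads(line) for line in handle]
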
def plot_predictions(filepath, input_items, colormap, model):
	filepath_jsonl = filepath.replace("_$$", "").replace(".png", ".jsonl")
	if os.path.exists(filepath_jsonl):
		os.truncate(filepath_jsonl, 0) # save_samples() appends, so clear any previous run's output first
	
	i = 0
	for input_pair in input_items:
		prediction_mask = infer(image_tensor=input_pair[0], model=model)
		prediction_mask_argmax = tf.argmax(prediction_mask, axis=2)
		# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
		prediction_colormap = decode_segmentation_masks(prediction_mask_argmax, colormap, 2)
		
		# print("DEBUG:plot_predictions INFER", str(prediction_mask.numpy().tolist()).replace("], [", "],\n["))
		
		plot_samples_matplotlib(
			filepath.replace("$$", str(i)),
			[
				# input_tensor,
				tf.math.reduce_max(input_pair[0][:, :, :-1], axis=-1), # rainfall only
				input_pair[0][:, :, -1], # heightmap
				input_pair[1], # label_colourmap,
				prediction_mask[:, :, 1], # raw logits for the water channel
				prediction_colormap
			]
		)
		save_samples(
			filepath_jsonl,
			prediction_mask.numpy().tolist()
		)
		i += 1
def get_from_batched(dataset, count):
	"""Unbatch the given dataset and return the first `count` (input, label) pairs."""
	result = []
	for batched in dataset:
		items_input = tf.unstack(batched[0], axis=0)
		items_label = tf.unstack(batched[1], axis=0)
		for item in zip(items_input, items_label):
			result.append(item)
			if len(result) >= count:
				return result
	return result # the dataset may contain fewer than `count` items

plot_predictions(
	os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
	get_from_batched(dataset_train, PREDICT_COUNT),
	colormap,
	model=model
)

if not PREDICT_AS_ONE:
	plot_predictions(
		os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
		get_from_batched(dataset_validate, PREDICT_COUNT),
		colormap,
		model=model
	)
	
	if dataset_test is not None:
		plot_predictions(
			os.path.join(DIR_OUTPUT, "predict_test_$$.png"),
			get_from_batched(dataset_test, PREDICT_COUNT),
			colormap,
			model=model
		)

logger.info(f"Complete at {datetime.now().isoformat()}, elapsed {(datetime.now() - time_start).total_seconds()} seconds")