mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
Add (untested) mono rainfall → water depth model
* sighs * Unfortunately I can't seem to get contrastive learning to work.....
This commit is contained in:
parent
ce194d9227
commit
3313f77c88
5 changed files with 260 additions and 3 deletions
100
aimodel/src/lib/ai/RainfallWaterMono.py
Normal file
100
aimodel/src/lib/ai/RainfallWaterMono.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from ..dataset.batched_iterator import batched_iterator
|
||||||
|
|
||||||
|
from ..io.find_paramsjson import find_paramsjson
|
||||||
|
from ..io.readfile import readfile
|
||||||
|
from ..io.writefile import writefile
|
||||||
|
|
||||||
|
from .model_rainfallwater_segmentation import model_rainfallwater_segmentation
|
||||||
|
from .helpers import make_callbacks
|
||||||
|
from .helpers import summarywriter
|
||||||
|
from .components.LayerConvNeXtGamma import LayerConvNeXtGamma
|
||||||
|
from .components.LayerStack2Image import LayerStack2Image
|
||||||
|
from .helpers.summarywriter import summarywriter
|
||||||
|
|
||||||
|
class RainfallWaterMono(object):
|
||||||
|
def __init__(self, dir_output=None, filepath_checkpoint=None, epochs=50, batch_size=64, **kwargs):
|
||||||
|
super(RainfallWaterMono, self).__init__()
|
||||||
|
|
||||||
|
self.dir_output = dir_output
|
||||||
|
self.epochs = epochs
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
|
||||||
|
if filepath_checkpoint == None:
|
||||||
|
if self.dir_output == None:
|
||||||
|
raise Exception("Error: dir_output was not specified, and since no checkpoint was loaded training mode is activated.")
|
||||||
|
if not os.path.exists(self.dir_output):
|
||||||
|
os.mkdir(self.dir_output)
|
||||||
|
|
||||||
|
self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
|
||||||
|
|
||||||
|
writefile(self.filepath_summary, "") # Empty the file ahead of time
|
||||||
|
self.make_model()
|
||||||
|
|
||||||
|
summarywriter(self.model, self.filepath_summary, append=True)
|
||||||
|
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
|
||||||
|
else:
|
||||||
|
self.load_model(filepath_checkpoint)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"epochs": self.epochs,
|
||||||
|
"batch_size": self.batch_size,
|
||||||
|
**self.kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_checkpoint(filepath_checkpoint, **hyperparams):
|
||||||
|
logger.info(f"Loading from checkpoint: {filepath_checkpoint}")
|
||||||
|
return RainfallWaterMono(filepath_checkpoint=filepath_checkpoint, **hyperparams)
|
||||||
|
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
self.model = model_rainfallwater_segmentation(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
**self.kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(self, filepath_checkpoint):
|
||||||
|
"""
|
||||||
|
Loads a saved model from the given filename.
|
||||||
|
filepath_checkpoint (string): The filepath to load the saved model from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.model = tf.keras.models.load_model(filepath_checkpoint, custom_objects={
|
||||||
|
"LayerConvNeXtGamma": LayerConvNeXtGamma,
|
||||||
|
"LayerStack2Image": LayerStack2Image
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def train(self, dataset_train, dataset_validate):
|
||||||
|
return self.model.fit(
|
||||||
|
dataset_train,
|
||||||
|
validation_data=dataset_validate,
|
||||||
|
epochs=self.epochs,
|
||||||
|
callbacks=make_callbacks(self.dir_output, self.model),
|
||||||
|
# steps_per_epoch=10 # For testing
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(self, rainfall_embed):
|
||||||
|
rainfall = self.model(rainfall_embed, training=False) # (rainfall_embed, water)
|
||||||
|
|
||||||
|
for step in tf.unstack(rainfall, axis=0):
|
||||||
|
yield step
|
||||||
|
|
||||||
|
|
||||||
|
# def embed_rainfall(self, dataset):
|
||||||
|
# result = []
|
||||||
|
# for batch in dataset:
|
||||||
|
# result_batch = self.model_predict(batch)
|
||||||
|
# result.extend(tf.unstack(result_batch, axis=0))
|
||||||
|
# return result
|
73
aimodel/src/lib/ai/model_rainfallwater_mono.py
Normal file
73
aimodel/src/lib/ai/model_rainfallwater_mono.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from .components.convnext import make_convnext
|
||||||
|
from .components.convnext_inverse import do_convnext_inverse
|
||||||
|
from .components.LayerStack2Image import LayerStack2Image
|
||||||
|
|
||||||
|
def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2):
|
||||||
|
"""Makes a new rainfall / waterdepth mono model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): A dictionary of metadata about the dataset to use to build the model with.
|
||||||
|
shape_water_out (int[]): The width and height (in that order) that should dictate the output shape of the segmentation head. CURRENTLY NOT USED.
|
||||||
|
model_arch (str, optional): The architecture code for the underlying (inverted) ConvNeXt model. Defaults to "convnext_i_xtiny".
|
||||||
|
batch_size (int, optional): The batch size. Reduce to save memory. Defaults to 64.
|
||||||
|
water_bins (int, optional): The number of classes that the water depth output oft he segmentation head should be binned into. Defaults to 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.keras.Model: The new model, freshly compiled for your convenience! :D
|
||||||
|
"""
|
||||||
|
rainfall_channels, rainfall_width, rainfall_height = metadata["rainfallradar"] # shape = [channels, width, height]
|
||||||
|
|
||||||
|
out_water_width, out_water_height = shape_water_out
|
||||||
|
|
||||||
|
layer_input = tf.keras.layers.Input(
|
||||||
|
shape=(rainfall_width, rainfall_height, rainfall_channels)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ENCODER
|
||||||
|
layer_next = make_convnext(
|
||||||
|
input_shape = (rainfall_width, rainfall_height, rainfall_channels),
|
||||||
|
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
|
||||||
|
num_classes = feature_dim, # size of the feature dimension, see the line above this one
|
||||||
|
arch_name = model_arch_enc
|
||||||
|
)(layer_input)
|
||||||
|
|
||||||
|
|
||||||
|
# BOTTLENECK
|
||||||
|
layer_next = tf.keras.layers.Dense(name="cns.stage.bottleneck.dense2", units=feature_dim)(layer_input)
|
||||||
|
layer_next = tf.keras.layers.Activation(name="cns.stage.bottleneck.gelu2", activation="gelu")(layer_next)
|
||||||
|
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage.bottleneck.norm2", epsilon=1e-6)(layer_next)
|
||||||
|
layer_next = tf.keras.layers.Dropout(name="cns.stage.bottleneck.dropout", rate=0.1)(layer_next)
|
||||||
|
|
||||||
|
# DECODER
|
||||||
|
layer_next = LayerStack2Image(target_width=4, target_height=4)(layer_next)
|
||||||
|
# layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next)
|
||||||
|
|
||||||
|
layer_next = tf.keras.layers.Dense(name="cns.stage.begin.dense2", units=feature_dim)(layer_next)
|
||||||
|
layer_next = tf.keras.layers.Activation(name="cns.stage_begin.relu2", activation="gelu")(layer_next)
|
||||||
|
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm2", epsilon=1e-6)(layer_next)
|
||||||
|
|
||||||
|
layer_next = do_convnext_inverse(layer_next, arch_name=model_arch_dec)
|
||||||
|
|
||||||
|
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
|
||||||
|
logger.warning("Warning: TODO implement attention from https://ieeexplore.ieee.org/document/9076883")
|
||||||
|
layer_next = tf.keras.layers.Dense(32, activation="gelu")(layer_next)
|
||||||
|
layer_next = tf.keras.layers.Conv2D(water_bins, activation="gelu", kernel_size=1, padding="same")(layer_next)
|
||||||
|
layer_next = tf.keras.layers.Softmax(axis=-1)(layer_next)
|
||||||
|
|
||||||
|
model = tf.keras.Model(
|
||||||
|
inputs = layer_input,
|
||||||
|
outputs = layer_next
|
||||||
|
)
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
optimizer="Adam",
|
||||||
|
loss=tf.keras.losses.CategoricalCrossentropy(),
|
||||||
|
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
|
@ -78,7 +78,7 @@ def get_filepaths(dirpath_input, do_shuffle=True):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5, dummy_label=True):
|
||||||
filepaths = get_filepaths(dirpath_input)
|
filepaths = get_filepaths(dirpath_input)
|
||||||
filepaths_count = len(filepaths)
|
filepaths_count = len(filepaths)
|
||||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||||
|
@ -88,8 +88,8 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
|
||||||
|
|
||||||
metadata = read_metadata(dirpath_input)
|
metadata = read_metadata(dirpath_input)
|
||||||
|
|
||||||
dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, dummy_label=dummy_label)
|
||||||
dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, dummy_label=dummy_label)
|
||||||
|
|
||||||
return dataset_train, dataset_validate #, filepaths
|
return dataset_train, dataset_validate #, filepaths
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ Available subcommands:
|
||||||
pretrain-plot Plot using embeddings predicted using pretrain-predict.
|
pretrain-plot Plot using embeddings predicted using pretrain-predict.
|
||||||
train Train an image segmentation head on the output of pretrain-predict. YOU MUST TRAIN A CONTRASTIVE LEARNING MODEL FIRST.
|
train Train an image segmentation head on the output of pretrain-predict. YOU MUST TRAIN A CONTRASTIVE LEARNING MODEL FIRST.
|
||||||
train-predict Make predictions using a model trained through the train subcommand.
|
train-predict Make predictions using a model trained through the train subcommand.
|
||||||
|
train-mono Train a mono rainfall → water depth model.
|
||||||
|
|
||||||
For more information, do src/index.py <subcommand> --help.
|
For more information, do src/index.py <subcommand> --help.
|
||||||
""")
|
""")
|
||||||
|
|
83
aimodel/src/subcommands/train_mono.py
Normal file
83
aimodel/src/subcommands/train_mono.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from asyncio.log import logger
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from lib.ai.RainfallWaterMono import RainfallWaterMono
|
||||||
|
from lib.dataset.dataset import dataset
|
||||||
|
from lib.dataset.read_metadata import read_metadata
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Train an mono rainfall-water model on a directory of .tfrecord.gz rainfall+waterdepth_label files.")
|
||||||
|
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
||||||
|
parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True)
|
||||||
|
parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True)
|
||||||
|
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
||||||
|
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||||
|
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
||||||
|
parser.add_argument("--water-threshold", help="The threshold at which a water cell should be considered water. Water depth values lower than this will be set to 0 (no water). Value unit is metres [default: 0.1].", type=int)
|
||||||
|
parser.add_argument("--bottleneck", help="The size of the bottleneck [default: 512].", type=int)
|
||||||
|
parser.add_argument("--arch-enc", help="Next of the underlying encoder convnext model to use [default: convnext_xtiny].")
|
||||||
|
parser.add_argument("--arch-dec", help="Next of the underlying decoder convnext model to use [default: convnext_i_xtiny].")
|
||||||
|
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def run(args):
|
||||||
|
if (not hasattr(args, "water_size")) or args.water_size == None:
|
||||||
|
args.water_size = 100
|
||||||
|
if (not hasattr(args, "batch_size")) or args.batch_size == None:
|
||||||
|
args.batch_size = 64
|
||||||
|
if (not hasattr(args, "feature_dim")) or args.feature_dim == None:
|
||||||
|
args.feature_dim = 512
|
||||||
|
if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None:
|
||||||
|
args.read_multiplier = 1.5
|
||||||
|
if (not hasattr(args, "water_threshold")) or args.water_threshold == None:
|
||||||
|
args.water_threshold = 1.5
|
||||||
|
if (not hasattr(args, "water_size")) or args.water_size == None:
|
||||||
|
args.water_size = 1.5
|
||||||
|
if (not hasattr(args, "bottleneck")) or args.bottleneck == None:
|
||||||
|
args.bottleneck = 512
|
||||||
|
if (not hasattr(args, "arch_enc")) or args.arch_enc == None:
|
||||||
|
args.arch_enc = "convnext_xtiny"
|
||||||
|
if (not hasattr(args, "arch_dec")) or args.arch_dec == None:
|
||||||
|
args.arch_dec = "convnext_i_xtiny"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Validate args here.
|
||||||
|
|
||||||
|
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||||
|
|
||||||
|
|
||||||
|
dataset_train, dataset_validate = dataset(
|
||||||
|
dirpath_input=args.input,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
water_threshold=args.water_threshold,
|
||||||
|
shape_water_desired=[args.water_size, args.water_size],
|
||||||
|
dummy_label=False
|
||||||
|
)
|
||||||
|
dataset_metadata = read_metadata(args.input)
|
||||||
|
|
||||||
|
# for (items, label) in dataset_train:
|
||||||
|
# print("ITEMS", len(items), [ item.shape for item in items ])
|
||||||
|
# print("LABEL", label.shape)
|
||||||
|
# print("ITEMS DONE")
|
||||||
|
# exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
ai = RainfallWaterMono(
|
||||||
|
dir_output=args.output,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
|
||||||
|
feature_dim=args.bottleneck,
|
||||||
|
model_arch_enc=args.arch_enc,
|
||||||
|
model_arch_dec=args.arch_dec,
|
||||||
|
|
||||||
|
metadata = read_metadata(args.input),
|
||||||
|
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
|
||||||
|
)
|
||||||
|
|
||||||
|
ai.train(dataset_train, dataset_validate)
|
||||||
|
|
Loading…
Reference in a new issue