mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
encoderonly model: getting there
This commit is contained in:
parent
4b7df39fac
commit
1958c4e6c2
2 changed files with 102 additions and 29 deletions
|
@ -4,8 +4,9 @@ import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from lib.dataset.dataset_encoderonly import dataset_encoderonly
|
||||||
from lib.ai.components.convnext import make_convnext
|
from lib.ai.components.convnext import make_convnext
|
||||||
|
from lib.ai.helpers.summarywriter import summarywriter
|
||||||
|
|
||||||
# ███████ ███ ██ ██ ██
|
# ███████ ███ ██ ██ ██
|
||||||
# ██ ████ ██ ██ ██
|
# ██ ████ ██ ██ ██
|
||||||
|
@ -14,8 +15,27 @@ from lib.ai.components.convnext import make_convnext
|
||||||
# ███████ ██ ████ ████
|
# ███████ ██ ████ ████
|
||||||
|
|
||||||
# TODO: env vars & settings here
|
# TODO: env vars & settings here
|
||||||
|
DIRPATH_INPUT = os.environ["DIRPATH_INPUT"]
|
||||||
|
DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"]
|
||||||
|
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] if "PATH_HEIGHTMAP" in os.environ else None
|
||||||
|
CHANNELS = os.environ["CHANNELS"] if "CHANNELS" in os.environ else 8
|
||||||
|
|
||||||
|
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||||
|
WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ else 33
|
||||||
|
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||||
|
|
||||||
|
logger.info("Encoder-only rainfall radar TEST")
|
||||||
|
logger.info(f"> DIRPATH_INPUT {DIRPATH_INPUT}")
|
||||||
|
logger.info(f"> DIRPATH_OUTPUT {DIRPATH_OUTPUT}")
|
||||||
|
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
|
||||||
|
logger.info(f"> CHANNELS {CHANNELS}")
|
||||||
|
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
||||||
|
logger.info(f"> WINDOW_SIZE {WINDOW_SIZE}")
|
||||||
|
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
|
||||||
|
|
||||||
|
|
||||||
|
if not os.path.exists(DIRPATH_OUTPUT):
|
||||||
|
os.makedirs(os.path.join(DIRPATH_OUTPUT, "checkpoints"))
|
||||||
|
|
||||||
|
|
||||||
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
||||||
|
@ -24,7 +44,12 @@ from lib.ai.components.convnext import make_convnext
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
|
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
|
||||||
|
|
||||||
|
dataset_train, dataset_validate = dataset_encoderonly(
|
||||||
|
dirpath_input=DIRPATH_INPUT,
|
||||||
|
filepath_heightmap=PATH_HEIGHTMAP,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
windowsize=WINDOW_SIZE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ███ ███ ██████ ██████ ███████ ██
|
# ███ ███ ██████ ██████ ███████ ██
|
||||||
|
@ -33,9 +58,42 @@ from lib.ai.components.convnext import make_convnext
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
# ██ ██ ██████ ██████ ███████ ███████
|
# ██ ██ ██████ ██████ ███████ ███████
|
||||||
|
|
||||||
|
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2):
|
||||||
|
if encoder == "convnext":
|
||||||
|
model = make_convnext(input_shape=(windowsize, windowsize, channels), num_classes=water_bins, **kwargs)
|
||||||
|
elif encoder == "resnet":
|
||||||
|
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
|
||||||
|
layer_next = tf.keras.applications.resnet50.ResNet50(
|
||||||
|
weights=None,
|
||||||
|
include_top=True,
|
||||||
|
classes=water_bins,
|
||||||
|
input_tensor=layer_in,
|
||||||
|
pooling="max",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = tf.keras.Model(
|
||||||
|
inputs = layer_in,
|
||||||
|
outputs = layer_next
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Error: Unknown encoder '{encoder}' (known encoders: convnext, resnet).")
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
optimizer="Adam",
|
||||||
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(),
|
||||||
|
metrics = [
|
||||||
|
tf.keras.metrics.SparseCategoricalAccuracy()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
# TODO: Have an
|
|
||||||
|
|
||||||
|
model = make_encoderonly(
|
||||||
|
windowsize=WINDOW_SIZE,
|
||||||
|
channels=CHANNELS
|
||||||
|
)
|
||||||
|
summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))
|
||||||
|
|
||||||
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
|
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
|
||||||
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
|
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
|
||||||
|
@ -43,6 +101,28 @@ from lib.ai.components.convnext import make_convnext
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
|
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
|
||||||
|
|
||||||
|
model.fit(dataset_train,
|
||||||
|
validation_data=dataset_validate,
|
||||||
|
epochs=25,
|
||||||
|
|
||||||
|
callbacks=[
|
||||||
|
tf.keras.callbacks.CSVLogger(
|
||||||
|
filename=os.path.join(DIRPATH_OUTPUT, "metrics.tsv"),
|
||||||
|
separator="\t"
|
||||||
|
),
|
||||||
|
CallbackCustomModelCheckpoint(
|
||||||
|
model_to_checkpoint=model,
|
||||||
|
filepath=os.path.join(
|
||||||
|
DIRPATH_OUTPUT,
|
||||||
|
"checkpoints"
|
||||||
|
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
|
||||||
|
),
|
||||||
|
monitor="loss"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
steps_per_epoch=STEPS_PER_EPOCH,
|
||||||
|
)
|
||||||
|
logger.info(">>> Training complete")
|
||||||
|
|
||||||
|
|
||||||
# ██████ ██████ ███████ ██████ ██ ██████ ████████ ██ ██████ ███ ██
|
# ██████ ██████ ███████ ██████ ██ ██████ ████████ ██ ██████ ███ ██
|
||||||
|
@ -52,7 +132,7 @@ from lib.ai.components.convnext import make_convnext
|
||||||
# ██ ██ ██ ███████ ██████ ██ ██████ ██ ██ ██████ ██ ████
|
# ██ ██ ██ ███████ ██████ ██ ██████ ██ ██ ██████ ██ ████
|
||||||
|
|
||||||
|
|
||||||
|
# TODO FILL THIS IN
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,19 +15,15 @@ from .parse_heightmap import parse_heightmap
|
||||||
|
|
||||||
|
|
||||||
# TO PARSE:
|
# TO PARSE:
|
||||||
def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1):
|
def parse_item(metadata, water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1, windowsize=33):
|
||||||
if input_size == "same":
|
if input_size == "same":
|
||||||
input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
|
input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
|
||||||
|
|
||||||
water_height_source, water_width_source = metadata["waterdepth"]
|
water_height_source, water_width_source = metadata["waterdepth"]
|
||||||
water_offset_x = math.ceil((water_width_source - output_size) / 2)
|
|
||||||
water_offset_y = math.ceil((water_height_source - output_size) / 2)
|
|
||||||
|
|
||||||
rainfall_channels, rainfall_height_source, rainfall_width_source = metadata["rainfallradar"]
|
rainfall_channels, rainfall_height_source, rainfall_width_source = metadata["rainfallradar"]
|
||||||
rainfall_height_source *= rainfall_scale_up
|
rainfall_height_source *= rainfall_scale_up
|
||||||
rainfall_width_source *= rainfall_scale_up
|
rainfall_width_source *= rainfall_scale_up
|
||||||
rainfall_offset_x = math.ceil((rainfall_width_source - input_size) / 2)
|
|
||||||
rainfall_offset_y = math.ceil((rainfall_height_source - input_size) / 2)
|
|
||||||
|
|
||||||
print("DEBUG DATASET:rainfall shape", metadata["rainfallradar"], "/", f"w {rainfall_width_source} h {rainfall_height_source}")
|
print("DEBUG DATASET:rainfall shape", metadata["rainfallradar"], "/", f"w {rainfall_width_source} h {rainfall_height_source}")
|
||||||
print("DEBUG DATASET:water shape", metadata["waterdepth"])
|
print("DEBUG DATASET:water shape", metadata["waterdepth"])
|
||||||
|
@ -35,8 +31,6 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
print("DEBUG DATASET:water_bins", water_bins)
|
print("DEBUG DATASET:water_bins", water_bins)
|
||||||
print("DEBUG DATASET:output_size", output_size)
|
print("DEBUG DATASET:output_size", output_size)
|
||||||
print("DEBUG DATASET:input_size", input_size)
|
print("DEBUG DATASET:input_size", input_size)
|
||||||
print("DEBUG DATASET:water_offset x", water_offset_x, "y", water_offset_y)
|
|
||||||
print("DEBUG DATASET:rainfall_offset x", rainfall_offset_x, "y", rainfall_offset_y)
|
|
||||||
|
|
||||||
if heightmap is not None:
|
if heightmap is not None:
|
||||||
heightmap = tf.expand_dims(heightmap, axis=-1)
|
heightmap = tf.expand_dims(heightmap, axis=-1)
|
||||||
|
@ -72,21 +66,14 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
rainfall = tf.concat([rainfall, heightmap], axis=-1)
|
rainfall = tf.concat([rainfall, heightmap], axis=-1)
|
||||||
if rainfall_scale_up > 1:
|
if rainfall_scale_up > 1:
|
||||||
rainfall = tf.repeat(tf.repeat(rainfall, rainfall_scale_up, axis=0), rainfall_scale_up, axis=1)
|
rainfall = tf.repeat(tf.repeat(rainfall, rainfall_scale_up, axis=0), rainfall_scale_up, axis=1)
|
||||||
if input_size is not None:
|
|
||||||
rainfall = tf.image.crop_to_bounding_box(rainfall,
|
|
||||||
offset_width=rainfall_offset_x,
|
|
||||||
offset_height=rainfall_offset_y,
|
|
||||||
target_width=input_size,
|
|
||||||
target_height=input_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# rainfall = tf.image.resize(rainfall, tf.cast(tf.constant(metadata["rainfallradar"]) / 2, dtype=tf.int32))
|
# rainfall = tf.image.resize(rainfall, tf.cast(tf.constant(metadata["rainfallradar"]) / 2, dtype=tf.int32))
|
||||||
water = tf.expand_dims(water, axis=-1) # [height, width] → [height, width, channels=1]
|
water = tf.expand_dims(water, axis=-1) # [height, width] → [height, width, channels=1]
|
||||||
water = tf.image.crop_to_bounding_box(water,
|
water = tf.image.crop_to_bounding_box(water, # for path generation; we only want the water bit where the patches convolve
|
||||||
offset_width=water_offset_x,
|
offset_width=math.floor(windowsize/2),
|
||||||
offset_height=water_offset_y,
|
offset_height=math.floor(windowsize/2),
|
||||||
target_width=output_size,
|
target_width=water_width_source - (math.floor(windowsize/2)*2),
|
||||||
target_height=output_size
|
target_height=water_height_source - (math.floor(windowsize/2)*2)
|
||||||
)
|
)
|
||||||
|
|
||||||
print("DEBUG:dataset BEFORE_SQUEEZE water", water.shape)
|
print("DEBUG:dataset BEFORE_SQUEEZE water", water.shape)
|
||||||
|
@ -96,7 +83,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||||
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
||||||
# LOSS dice
|
# LOSS dice
|
||||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||||
|
|
||||||
rainfall = tf.image.extract_patches(tf.expand_dims(rainfall, axis=0),
|
rainfall = tf.image.extract_patches(tf.expand_dims(rainfall, axis=0),
|
||||||
sizes=[1,windowsize,windowsize],
|
sizes=[1,windowsize,windowsize],
|
||||||
|
@ -106,7 +93,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
)
|
)
|
||||||
rainfall = tf.reshape(rainfall, [-1, windowsize, windowsize, rainfall_channels])
|
rainfall = tf.reshape(rainfall, [-1, windowsize, windowsize, rainfall_channels])
|
||||||
|
|
||||||
# TODO: extract single water values here to match the above rainfall patches
|
water = tf.reshape([-1]) # we flatten because we cropped to the right shape above
|
||||||
|
|
||||||
print("DEBUG DATASET_OUT:rainfall shape", rainfall.shape)
|
print("DEBUG DATASET_OUT:rainfall shape", rainfall.shape)
|
||||||
print("DEBUG DATASET_OUT:water shape", water.shape)
|
print("DEBUG DATASET_OUT:water shape", water.shape)
|
||||||
|
@ -114,7 +101,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
|
|
||||||
return tf.function(parse_item_inner)
|
return tf.function(parse_item_inner)
|
||||||
|
|
||||||
def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, prefetch=True, shuffle=True, filepath_heightmap=None, **kwargs):
|
def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=3, shuffle_buffer_size=2**16, batch_size=64, prefetch=True, shuffle=True, filepath_heightmap=None, **kwargs):
|
||||||
if "NO_PREFETCH" in os.environ:
|
if "NO_PREFETCH" in os.environ:
|
||||||
logger.info("disabling data prefetching.")
|
logger.info("disabling data prefetching.")
|
||||||
|
|
||||||
|
@ -128,10 +115,16 @@ def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=1
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) if parallel_reads_multiplier > 0 else None
|
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) if parallel_reads_multiplier > 0 else None
|
||||||
)
|
)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
dataset = dataset.shuffle(shuffle_buffer_size)
|
dataset = dataset.shuffle(128) # additional shuffle buffer to mix things up
|
||||||
|
|
||||||
dataset = dataset.map(parse_item(heightmap=heightmap, **kwargs), num_parallel_calls=tf.data.AUTOTUNE) \
|
dataset = dataset.map(parse_item(heightmap=heightmap, **kwargs), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||||
.unbatch()
|
.unbatch()
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
# memory used = (windowsize*windowsize + 1) * buffersize * channels
|
||||||
|
# defaults = (33*33 + 1) * 2**16 * 8 = about 2.219GiB
|
||||||
|
dataset = dataset.shuffle(shuffle_buffer_size)
|
||||||
|
|
||||||
if batch_size != None:
|
if batch_size != None:
|
||||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||||
if prefetch:
|
if prefetch:
|
||||||
|
@ -152,7 +145,7 @@ def get_filepaths(dirpath_input, do_shuffle=True):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
|
def dataset_encoderonly(dirpath_input, train_percentage=0.8, **kwargs):
|
||||||
filepaths = get_filepaths(dirpath_input)
|
filepaths = get_filepaths(dirpath_input)
|
||||||
filepaths_count = len(filepaths)
|
filepaths_count = len(filepaths)
|
||||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||||
|
@ -167,7 +160,7 @@ def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
|
||||||
|
|
||||||
return dataset_train, dataset_validate #, filepaths
|
return dataset_train, dataset_validate #, filepaths
|
||||||
|
|
||||||
def dataset_mono_predict(dirpath_input, **kwargs):
|
def dataset_encoderonly_predict(dirpath_input, **kwargs):
|
||||||
"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.
|
"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.
|
||||||
Note that this WILL MANGLE THE ORDERING if you set parallel_reads_multiplier to anything other than 0!!
|
Note that this WILL MANGLE THE ORDERING if you set parallel_reads_multiplier to anything other than 0!!
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue