encoderonly model: getting there

This commit is contained in:
Starbeamrainbowlabs 2023-01-09 19:33:41 +00:00
parent 4b7df39fac
commit 1958c4e6c2
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 102 additions and 29 deletions

View file

@ -4,8 +4,9 @@ import os
import tensorflow as tf
from lib.dataset.dataset_encoderonly import dataset_encoderonly
from lib.ai.components.convnext import make_convnext
from lib.ai.helpers.summarywriter import summarywriter
# ███████ ███ ██ ██ ██
# ██ ████ ██ ██ ██
@ -14,8 +15,27 @@ from lib.ai.components.convnext import make_convnext
# ███████ ██ ████ ████
# TODO: env vars & settings here
DIRPATH_INPUT = os.environ["DIRPATH_INPUT"]
DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"]
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] if "PATH_HEIGHTMAP" in os.environ else None
CHANNELS = os.environ["CHANNELS"] if "CHANNELS" in os.environ else 8
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ else 33
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
logger.info("Encoder-only rainfall radar TEST")
logger.info(f"> DIRPATH_INPUT {DIRPATH_INPUT}")
logger.info(f"> DIRPATH_OUTPUT {DIRPATH_OUTPUT}")
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> CHANNELS {CHANNELS}")
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
logger.info(f"> WINDOW_SIZE {WINDOW_SIZE}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
if not os.path.exists(DIRPATH_OUTPUT):
os.makedirs(os.path.join(DIRPATH_OUTPUT, "checkpoints"))
# ██████ █████ ████████ █████ ███████ ███████ ████████
@ -24,7 +44,12 @@ from lib.ai.components.convnext import make_convnext
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
dataset_train, dataset_validate = dataset_encoderonly(
dirpath_input=DIRPATH_INPUT,
filepath_heightmap=PATH_HEIGHTMAP,
batch_size=BATCH_SIZE,
windowsize=WINDOW_SIZE
)
# ███ ███ ██████ ██████ ███████ ██
@ -33,9 +58,42 @@ from lib.ai.components.convnext import make_convnext
# ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██████ ██████ ███████ ███████
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2):
if encoder == "convnext":
model = make_convnext(input_shape=(windowsize, windowsize, channels), num_classes=water_bins, **kwargs)
elif encoder == "resnet":
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
layer_next = tf.keras.applications.resnet50.ResNet50(
weights=None,
include_top=True,
classes=water_bins,
input_tensor=layer_in,
pooling="max",
)
model = tf.keras.Model(
inputs = layer_in,
outputs = layer_next
)
else:
raise Exception(f"Error: Unknown encoder '{encoder}' (known encoders: convnext, resnet).")
model.compile(
optimizer="Adam",
loss = tf.keras.losses.SparseCategoricalCrossentropy(),
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy()
]
)
return model
# TODO: Have an
model = make_encoderonly(
windowsize=WINDOW_SIZE,
channels=CHANNELS
)
summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
@ -43,6 +101,28 @@ from lib.ai.components.convnext import make_convnext
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
model.fit(dataset_train,
validation_data=dataset_validate,
epochs=25,
callbacks=[
tf.keras.callbacks.CSVLogger(
filename=os.path.join(DIRPATH_OUTPUT, "metrics.tsv"),
separator="\t"
),
CallbackCustomModelCheckpoint(
model_to_checkpoint=model,
filepath=os.path.join(
DIRPATH_OUTPUT,
"checkpoints"
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
),
monitor="loss"
),
],
steps_per_epoch=STEPS_PER_EPOCH,
)
logger.info(">>> Training complete")
# ██████ ██████ ███████ ██████ ██ ██████ ████████ ██ ██████ ███ ██
@ -52,7 +132,7 @@ from lib.ai.components.convnext import make_convnext
# ██ ██ ██ ███████ ██████ ██ ██████ ██ ██ ██████ ██ ████
# TODO FILL THIS IN

View file

@ -15,19 +15,15 @@ from .parse_heightmap import parse_heightmap
# TO PARSE:
def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1):
def parse_item(metadata, water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1, windowsize=33):
if input_size == "same":
input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
water_height_source, water_width_source = metadata["waterdepth"]
water_offset_x = math.ceil((water_width_source - output_size) / 2)
water_offset_y = math.ceil((water_height_source - output_size) / 2)
rainfall_channels, rainfall_height_source, rainfall_width_source = metadata["rainfallradar"]
rainfall_height_source *= rainfall_scale_up
rainfall_width_source *= rainfall_scale_up
rainfall_offset_x = math.ceil((rainfall_width_source - input_size) / 2)
rainfall_offset_y = math.ceil((rainfall_height_source - input_size) / 2)
print("DEBUG DATASET:rainfall shape", metadata["rainfallradar"], "/", f"w {rainfall_width_source} h {rainfall_height_source}")
print("DEBUG DATASET:water shape", metadata["waterdepth"])
@ -35,8 +31,6 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
print("DEBUG DATASET:water_bins", water_bins)
print("DEBUG DATASET:output_size", output_size)
print("DEBUG DATASET:input_size", input_size)
print("DEBUG DATASET:water_offset x", water_offset_x, "y", water_offset_y)
print("DEBUG DATASET:rainfall_offset x", rainfall_offset_x, "y", rainfall_offset_y)
if heightmap is not None:
heightmap = tf.expand_dims(heightmap, axis=-1)
@ -72,21 +66,14 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
rainfall = tf.concat([rainfall, heightmap], axis=-1)
if rainfall_scale_up > 1:
rainfall = tf.repeat(tf.repeat(rainfall, rainfall_scale_up, axis=0), rainfall_scale_up, axis=1)
if input_size is not None:
rainfall = tf.image.crop_to_bounding_box(rainfall,
offset_width=rainfall_offset_x,
offset_height=rainfall_offset_y,
target_width=input_size,
target_height=input_size,
)
# rainfall = tf.image.resize(rainfall, tf.cast(tf.constant(metadata["rainfallradar"]) / 2, dtype=tf.int32))
water = tf.expand_dims(water, axis=-1) # [height, width] → [height, width, channels=1]
water = tf.image.crop_to_bounding_box(water,
offset_width=water_offset_x,
offset_height=water_offset_y,
target_width=output_size,
target_height=output_size
water = tf.image.crop_to_bounding_box(water, # for path generation; we only want the water bit where the patches convolve
offset_width=math.floor(windowsize/2),
offset_height=math.floor(windowsize/2),
target_width=water_width_source - (math.floor(windowsize/2)*2),
target_height=water_height_source - (math.floor(windowsize/2)*2)
)
print("DEBUG:dataset BEFORE_SQUEEZE water", water.shape)
@ -96,7 +83,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
# LOSS dice
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
rainfall = tf.image.extract_patches(tf.expand_dims(rainfall, axis=0),
sizes=[1,windowsize,windowsize],
@ -106,7 +93,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
)
rainfall = tf.reshape(rainfall, [-1, windowsize, windowsize, rainfall_channels])
# TODO: extract single water values here to match the above rainfall patches
water = tf.reshape([-1]) # we flatten because we cropped to the right shape above
print("DEBUG DATASET_OUT:rainfall shape", rainfall.shape)
print("DEBUG DATASET_OUT:water shape", water.shape)
@ -114,7 +101,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
return tf.function(parse_item_inner)
def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, prefetch=True, shuffle=True, filepath_heightmap=None, **kwargs):
def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=3, shuffle_buffer_size=2**16, batch_size=64, prefetch=True, shuffle=True, filepath_heightmap=None, **kwargs):
if "NO_PREFETCH" in os.environ:
logger.info("disabling data prefetching.")
@ -128,10 +115,16 @@ def make_dataset(filepaths, compression_type="GZIP", parallel_reads_multiplier=1
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) if parallel_reads_multiplier > 0 else None
)
if shuffle:
dataset = dataset.shuffle(shuffle_buffer_size)
dataset = dataset.shuffle(128) # additional shuffle buffer to mix things up
dataset = dataset.map(parse_item(heightmap=heightmap, **kwargs), num_parallel_calls=tf.data.AUTOTUNE) \
.unbatch()
if shuffle:
# memory used = (windowsize*windowsize + 1) * buffersize * channels
# defaults = (33*33 + 1) * 2**16 * 8 = about 2.219GiB
dataset = dataset.shuffle(shuffle_buffer_size)
if batch_size != None:
dataset = dataset.batch(batch_size, drop_remainder=True)
if prefetch:
@ -152,7 +145,7 @@ def get_filepaths(dirpath_input, do_shuffle=True):
return result
def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
def dataset_encoderonly(dirpath_input, train_percentage=0.8, **kwargs):
filepaths = get_filepaths(dirpath_input)
filepaths_count = len(filepaths)
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
@ -167,7 +160,7 @@ def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
return dataset_train, dataset_validate #, filepaths
def dataset_mono_predict(dirpath_input, **kwargs):
def dataset_encoderonly_predict(dirpath_input, **kwargs):
"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.
Note that this WILL MANGLE THE ORDERING if you set parallel_reads_multiplier to anything other than 0!!