ai: these shapes are so annoying

This commit is contained in:
Starbeamrainbowlabs 2022-09-02 18:39:24 +01:00
parent 88acd54a97
commit c33c0a0899
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
4 changed files with 28 additions and 14 deletions

View file

@ -59,7 +59,11 @@ class RainfallWaterContraster(object):
def make_model(self):
return model_rainfallwater_contrastive(batch_size=self.batch_size, summary_file=self.filepath_summary, **self.kwargs)
return model_rainfallwater_contrastive(
batch_size=self.batch_size,
summary_file=self.filepath_summary,
**self.kwargs
)
def load_model(self, filepath_checkpoint):

View file

@ -1,4 +1,5 @@
from curses import meta
from loguru import logger
import tensorflow as tf
@ -6,15 +7,20 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
from .components.LossContrastive import LossContrastive
def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
logger.info(shape_rainfall)
logger.info(shape_water)
def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
# Shapes come from what rainfallwrangler sees them as, but we add an extra dimension when reading the .tfrecord file
rainfall_channels, rainfall_width, rainfall_height = shape_rainfall # shape = [channels, width, height]
rainfall_channels, rainfall_width, rainfall_height = metadata["rainfallradar"] # shape = [channels, width, height]
water_width, water_height = shape_water # shape = [width, height]
water_channels = 1 # added in dataset → make_dataset → parse_item
rainfall_width, rainfall_height = rainfall_width / 2, rainfall_height / 2
logger.info("SOURCE shape_rainfall " + str(metadata["rainfallradar"]))
logger.info("SOURCE shape_water " + str(metadata["waterdepth"]))
logger.info("TARGET shape_water" + str(shape_water))
logger.info("TARGET shape_rainfall" + str([ rainfall_width, rainfall_height, rainfall_channels ]))
input_rainfall = tf.keras.layers.Input(
shape=(rainfall_width, rainfall_height, rainfall_channels)
)

View file

@ -14,7 +14,11 @@ from .shuffle import shuffle
# TO PARSE:
def parse_item(metadata):
def parse_item(metadata, shape_water_desired):
water_width_source, water_height_source = metadata["waterdepth"]
water_width_target, water_height_target = shape_water_desired
water_offset_x = math.ceil((water_width_source - water_width_target) / 2)
water_offset_y = math.ceil((water_height_source - water_height_target) / 2)
def parse_item_inner(item):
parsed = tf.io.parse_single_example(item, features={
"rainfallradar": tf.io.FixedLenFeature([], tf.string),
@ -29,9 +33,9 @@ def parse_item(metadata):
rainfall = tf.transpose(rainfall, [1, 2, 0])
rainfall = tf.image.resize(rainfall, tf.cast(tf.constant(metadata["waterdepth"]) / 2, dtype=tf.int32))
# [width, height] → [width, height, channels]
water = tf.expand_dims(water, axis=-1)
water = tf.image.central_crop(water, 0.5) # Predict for only the centre 75% of the water data
water = tf.image.crop_to_bounding_box(water, water_offset_x, water_offset_y, water_width_target, water_height_target)
water = tf.expand_dims(water, axis=-1) # [width, height] → [width, height, channels]
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape)
@ -40,14 +44,14 @@ def parse_item(metadata):
return tf.function(parse_item_inner)
def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
if "NO_PREFETCH" in os.environ:
logger.info("disabling data prefetching.")
return tf.data.TFRecordDataset(filenames,
compression_type=compression_type,
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
).shuffle(shuffle_buffer_size) \
.map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \
.map(parse_item(metadata, shape_water_desired=shape_watch_desired), num_parallel_calls=tf.data.AUTOTUNE) \
.batch(batch_size) \
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)

View file

@ -52,8 +52,8 @@ def run(args):
batch_size=args.batch_size,
feature_dim=args.feature_dim,
shape_rainfall=dataset_metadata["rainfallradar"],
shape_water=[ math.ceil(value * 0.5) + 1 for value in dataset_metadata["waterdepth"] ]
metadata = read_metadata(args.input),
shape_water=[ 100, 100 ] # The DESIRED
)
ai.train(dataset_train, dataset_validate)