mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
ai: these shapes are so annoying
This commit is contained in:
parent
88acd54a97
commit
c33c0a0899
4 changed files with 28 additions and 14 deletions
|
@ -59,7 +59,11 @@ class RainfallWaterContraster(object):
|
|||
|
||||
|
||||
def make_model(self):
|
||||
return model_rainfallwater_contrastive(batch_size=self.batch_size, summary_file=self.filepath_summary, **self.kwargs)
|
||||
return model_rainfallwater_contrastive(
|
||||
batch_size=self.batch_size,
|
||||
summary_file=self.filepath_summary,
|
||||
**self.kwargs
|
||||
)
|
||||
|
||||
|
||||
def load_model(self, filepath_checkpoint):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
|
||||
from curses import meta
|
||||
from loguru import logger
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -6,15 +7,20 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
|
|||
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
|
||||
from .components.LossContrastive import LossContrastive
|
||||
|
||||
def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
|
||||
logger.info(shape_rainfall)
|
||||
logger.info(shape_water)
|
||||
|
||||
def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
|
||||
# Shapes come from what rainfallwrangler sees them as, but we add an extra dimension when reading the .tfrecord file
|
||||
rainfall_channels, rainfall_width, rainfall_height = shape_rainfall # shape = [channels, width, height]
|
||||
rainfall_channels, rainfall_width, rainfall_height = metadata["rainfallradar"] # shape = [channels, width, height]
|
||||
water_width, water_height = shape_water # shape = [width, height]
|
||||
water_channels = 1 # added in dataset → make_dataset → parse_item
|
||||
|
||||
rainfall_width, rainfall_height = rainfall_width / 2, rainfall_height / 2
|
||||
|
||||
logger.info("SOURCE shape_rainfall " + str(metadata["rainfallradar"]))
|
||||
logger.info("SOURCE shape_water " + str(metadata["waterdepth"]))
|
||||
logger.info("TARGET shape_water" + str(shape_water))
|
||||
logger.info("TARGET shape_rainfall" + str([ rainfall_width, rainfall_height, rainfall_channels ]))
|
||||
|
||||
|
||||
input_rainfall = tf.keras.layers.Input(
|
||||
shape=(rainfall_width, rainfall_height, rainfall_channels)
|
||||
)
|
||||
|
|
|
@ -14,7 +14,11 @@ from .shuffle import shuffle
|
|||
|
||||
|
||||
# TO PARSE:
|
||||
def parse_item(metadata):
|
||||
def parse_item(metadata, shape_water_desired):
|
||||
water_width_source, water_height_source = metadata["waterdepth"]
|
||||
water_width_target, water_height_target = shape_water_desired
|
||||
water_offset_x = math.ceil((water_width_source - water_width_target) / 2)
|
||||
water_offset_y = math.ceil((water_height_source - water_height_target) / 2)
|
||||
def parse_item_inner(item):
|
||||
parsed = tf.io.parse_single_example(item, features={
|
||||
"rainfallradar": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -29,9 +33,9 @@ def parse_item(metadata):
|
|||
|
||||
rainfall = tf.transpose(rainfall, [1, 2, 0])
|
||||
rainfall = tf.image.resize(rainfall, tf.cast(tf.constant(metadata["waterdepth"]) / 2, dtype=tf.int32))
|
||||
# [width, height] → [width, height, channels]
|
||||
water = tf.expand_dims(water, axis=-1)
|
||||
water = tf.image.central_crop(water, 0.5) # Predict for only the centre 75% of the water data
|
||||
|
||||
water = tf.image.crop_to_bounding_box(water, water_offset_x, water_offset_y, water_width_target, water_height_target)
|
||||
water = tf.expand_dims(water, axis=-1) # [width, height] → [width, height, channels]
|
||||
|
||||
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
||||
print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape)
|
||||
|
@ -40,14 +44,14 @@ def parse_item(metadata):
|
|||
|
||||
return tf.function(parse_item_inner)
|
||||
|
||||
def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||
def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||
if "NO_PREFETCH" in os.environ:
|
||||
logger.info("disabling data prefetching.")
|
||||
return tf.data.TFRecordDataset(filenames,
|
||||
compression_type=compression_type,
|
||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||
).shuffle(shuffle_buffer_size) \
|
||||
.map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||
.map(parse_item(metadata, shape_water_desired=shape_watch_desired), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||
.batch(batch_size) \
|
||||
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||
|
||||
|
|
|
@ -52,8 +52,8 @@ def run(args):
|
|||
batch_size=args.batch_size,
|
||||
feature_dim=args.feature_dim,
|
||||
|
||||
shape_rainfall=dataset_metadata["rainfallradar"],
|
||||
shape_water=[ math.ceil(value * 0.5) + 1 for value in dataset_metadata["waterdepth"] ]
|
||||
metadata = read_metadata(args.input),
|
||||
shape_water=[ 100, 100 ] # The DESIRED
|
||||
)
|
||||
|
||||
ai.train(dataset_train, dataset_validate)
|
||||
|
|
Loading…
Reference in a new issue