From fa3165a5b25778c7e0896484e47d99f9ef5d1c15 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 14 Sep 2022 17:33:17 +0100 Subject: [PATCH] dataset: simplify dataset_predict --- aimodel/src/lib/dataset/dataset.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py index 403c158..203cc10 100644 --- a/aimodel/src/lib/dataset/dataset.py +++ b/aimodel/src/lib/dataset/dataset.py @@ -14,7 +14,7 @@ from .shuffle import shuffle # TO PARSE: -def parse_item(metadata, shape_water_desired): +def parse_item(metadata, shape_water_desired, dummy_label=True): water_width_source, water_height_source = metadata["waterdepth"] water_width_target, water_height_target = shape_water_desired water_offset_x = math.ceil((water_width_source - water_width_target) / 2) @@ -39,18 +39,21 @@ def parse_item(metadata, shape_water_desired): print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape) # TODO: Any other additional parsing here, since multiple .map() calls are not optimal - return ((rainfall, water), tf.ones(1)) + if dummy_label: + return ((rainfall, water), tf.ones(1)) + else: + return rainfall, water return tf.function(parse_item_inner) -def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64): +def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True): if "NO_PREFETCH" in os.environ: logger.info("disabling data prefetching.") return tf.data.TFRecordDataset(filepaths, compression_type=compression_type, num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) ).shuffle(shuffle_buffer_size) \ - .map(parse_item(metadata, shape_water_desired=shape_watch_desired), num_parallel_calls=tf.data.AUTOTUNE) \ + .map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE) \ .batch(batch_size, drop_remainder=True) \ .prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE) @@ -86,7 +89,8 @@ def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5) filepaths=filepaths, metadata=read_metadata(dirpath_input), batch_size=batch_size, - parallel_reads_multiplier=parallel_reads_multiplier + parallel_reads_multiplier=parallel_reads_multiplier, + dummy_label=False ), filepaths[0:filepaths_count], filepaths_count if __name__ == "__main__":