mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
dataset: simplify dataset_predict
This commit is contained in:
parent
279e27c898
commit
fa3165a5b2
1 changed files with 9 additions and 5 deletions
|
@ -14,7 +14,7 @@ from .shuffle import shuffle
|
||||||
|
|
||||||
|
|
||||||
# TO PARSE:
|
# TO PARSE:
|
||||||
def parse_item(metadata, shape_water_desired):
|
def parse_item(metadata, shape_water_desired, dummy_label=True):
|
||||||
water_width_source, water_height_source = metadata["waterdepth"]
|
water_width_source, water_height_source = metadata["waterdepth"]
|
||||||
water_width_target, water_height_target = shape_water_desired
|
water_width_target, water_height_target = shape_water_desired
|
||||||
water_offset_x = math.ceil((water_width_source - water_width_target) / 2)
|
water_offset_x = math.ceil((water_width_source - water_width_target) / 2)
|
||||||
|
@ -39,18 +39,21 @@ def parse_item(metadata, shape_water_desired):
|
||||||
|
|
||||||
print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape)
|
print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape)
|
||||||
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
||||||
return ((rainfall, water), tf.ones(1))
|
if dummy_label:
|
||||||
|
return ((rainfall, water), tf.ones(1))
|
||||||
|
else:
|
||||||
|
return rainfall, water
|
||||||
|
|
||||||
return tf.function(parse_item_inner)
|
return tf.function(parse_item_inner)
|
||||||
|
|
||||||
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True):
|
||||||
if "NO_PREFETCH" in os.environ:
|
if "NO_PREFETCH" in os.environ:
|
||||||
logger.info("disabling data prefetching.")
|
logger.info("disabling data prefetching.")
|
||||||
return tf.data.TFRecordDataset(filepaths,
|
return tf.data.TFRecordDataset(filepaths,
|
||||||
compression_type=compression_type,
|
compression_type=compression_type,
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||||
).shuffle(shuffle_buffer_size) \
|
).shuffle(shuffle_buffer_size) \
|
||||||
.map(parse_item(metadata, shape_water_desired=shape_watch_desired), num_parallel_calls=tf.data.AUTOTUNE) \
|
.map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||||
.batch(batch_size, drop_remainder=True) \
|
.batch(batch_size, drop_remainder=True) \
|
||||||
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
@ -86,7 +89,8 @@ def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5)
|
||||||
filepaths=filepaths,
|
filepaths=filepaths,
|
||||||
metadata=read_metadata(dirpath_input),
|
metadata=read_metadata(dirpath_input),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
parallel_reads_multiplier=parallel_reads_multiplier
|
parallel_reads_multiplier=parallel_reads_multiplier,
|
||||||
|
dummy_label=False
|
||||||
), filepaths[0:filepaths_count], filepaths_count
|
), filepaths[0:filepaths_count], filepaths_count
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue