From d5f1a26ba3ad2aa7c16a7fe875b3cdd41afff8ef Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Thu, 15 Sep 2022 17:09:09 +0100
Subject: [PATCH] disable prefetching when predicting a thing

---
 aimodel/src/lib/dataset/dataset.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py
index 203cc10..afa8d27 100644
--- a/aimodel/src/lib/dataset/dataset.py
+++ b/aimodel/src/lib/dataset/dataset.py
@@ -46,16 +46,21 @@ def parse_item(metadata, shape_water_desired, dummy_label=True):
 	return tf.function(parse_item_inner)
 
 
-def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True):
+def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True, prefetch=True):
 	if "NO_PREFETCH" in os.environ:
 		logger.info("disabling data prefetching.")
-	return tf.data.TFRecordDataset(filepaths,
+
+	dataset = tf.data.TFRecordDataset(filepaths,
 		compression_type=compression_type,
 		num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
 	).shuffle(shuffle_buffer_size) \
 		.map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE) \
-		.batch(batch_size, drop_remainder=True) \
-		.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
+		.batch(batch_size, drop_remainder=True)
+
+	if prefetch:
+		dataset = dataset.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
+
+	return dataset
 
 
 def get_filepaths(dirpath_input):
@@ -79,7 +84,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
 	return dataset_train, dataset_validate #, filepaths
 
 
-def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
+def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5, prefetch=False):
 	filepaths = get_filepaths(dirpath_input)
 	filepaths_count = len(filepaths)
 	for i in range(len(filepaths)):
@@ -90,7 +95,8 @@ def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5)
 		metadata=read_metadata(dirpath_input),
 		batch_size=batch_size,
 		parallel_reads_multiplier=parallel_reads_multiplier,
-		dummy_label=False
+		dummy_label=False,
+		prefetch=prefetch
 	), filepaths[0:filepaths_count], filepaths_count
 
 if __name__ == "__main__":
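
For reviewers, a minimal sketch of the intended call site after this change: prediction now builds its pipeline without .prefetch(), so batches are only materialised as the prediction loop consumes them. The import path mirrors aimodel/src/lib/dataset/dataset.py, but the input directory and 'model' below are placeholders, not names from this patch.

	# Hypothetical caller, run from aimodel/src; assumes the GZIP TFRecord
	# directory layout that get_filepaths()/read_metadata() already expect.
	from lib.dataset.dataset import dataset_predict

	# dataset_predict() returns a 3-tuple: the tf.data pipeline, the list of
	# input filepaths, and their count.
	dataset, filepaths, filepaths_count = dataset_predict(
		"path/to/input_dir",            # placeholder input directory
		batch_size=64,
		parallel_reads_multiplier=1.5,
		prefetch=False,                 # default after this patch: no prefetch buffer
	)

	for batch in dataset:
		predictions = model(batch, training=False)  # 'model' is a placeholder

Note that make_dataset() keeps prefetch=True by default, so the training path is unaffected; only dataset_predict() opts out.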