diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py index ae5e387..49575b2 100644 --- a/aimodel/src/lib/dataset/dataset.py +++ b/aimodel/src/lib/dataset/dataset.py @@ -40,13 +40,15 @@ def parse_item(metadata): return tf.function(parse_item_inner) def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64): + if "NO_PREFETCH" in os.environ: + logger.info("disabling data prefetching.") return tf.data.TFRecordDataset(filenames, compression_type=compression_type, num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) ).shuffle(shuffle_buffer_size) \ .map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \ .batch(batch_size) \ - .prefetch(tf.data.AUTOTUNE) + .prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE) def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):