Add NO_PREFETCH env var

This commit is contained in:
Starbeamrainbowlabs 2022-09-02 17:55:04 +01:00
parent 3e0ca6a315
commit 3d44831080
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -40,13 +40,15 @@ def parse_item(metadata):
return tf.function(parse_item_inner) return tf.function(parse_item_inner)
def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64): def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
if "NO_PREFETCH" in os.environ:
logger.info("disabling data prefetching.")
return tf.data.TFRecordDataset(filenames, return tf.data.TFRecordDataset(filenames,
compression_type=compression_type, compression_type=compression_type,
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
).shuffle(shuffle_buffer_size) \ ).shuffle(shuffle_buffer_size) \
.map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \ .map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \
.batch(batch_size) \ .batch(batch_size) \
.prefetch(tf.data.AUTOTUNE) .prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5): def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):