mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
Add NO_PREFETCH env var
This commit is contained in:
parent
3e0ca6a315
commit
3d44831080
1 changed files with 3 additions and 1 deletions
|
@ -40,13 +40,15 @@ def parse_item(metadata):
|
||||||
return tf.function(parse_item_inner)
|
return tf.function(parse_item_inner)
|
||||||
|
|
||||||
def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||||
|
if "NO_PREFETCH" in os.environ:
|
||||||
|
logger.info("disabling data prefetching.")
|
||||||
return tf.data.TFRecordDataset(filenames,
|
return tf.data.TFRecordDataset(filenames,
|
||||||
compression_type=compression_type,
|
compression_type=compression_type,
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||||
).shuffle(shuffle_buffer_size) \
|
).shuffle(shuffle_buffer_size) \
|
||||||
.map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \
|
.map(parse_item(metadata), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||||
.batch(batch_size) \
|
.batch(batch_size) \
|
||||||
.prefetch(tf.data.AUTOTUNE)
|
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
|
||||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||||
|
|
Loading…
Reference in a new issue