mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
disable prefetching when predicting a thing
This commit is contained in:
parent
8770638022
commit
d5f1a26ba3
1 changed files with 12 additions and 6 deletions
|
@ -46,16 +46,21 @@ def parse_item(metadata, shape_water_desired, dummy_label=True):
|
|||
|
||||
return tf.function(parse_item_inner)
|
||||
|
||||
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True):
|
||||
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, dummy_label=True, prefetch=True):
|
||||
if "NO_PREFETCH" in os.environ:
|
||||
logger.info("disabling data prefetching.")
|
||||
return tf.data.TFRecordDataset(filepaths,
|
||||
|
||||
dataset = tf.data.TFRecordDataset(filepaths,
|
||||
compression_type=compression_type,
|
||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||
).shuffle(shuffle_buffer_size) \
|
||||
.map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE) \
|
||||
.batch(batch_size, drop_remainder=True) \
|
||||
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||
.batch(batch_size, drop_remainder=True)
|
||||
|
||||
if prefetch:
|
||||
dataset = dataset.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_filepaths(dirpath_input):
|
||||
|
@ -79,7 +84,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
|
|||
|
||||
return dataset_train, dataset_validate #, filepaths
|
||||
|
||||
def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
|
||||
def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5, pretrain=False):
|
||||
filepaths = get_filepaths(dirpath_input)
|
||||
filepaths_count = len(filepaths)
|
||||
for i in range(len(filepaths)):
|
||||
|
@ -90,7 +95,8 @@ def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5)
|
|||
metadata=read_metadata(dirpath_input),
|
||||
batch_size=batch_size,
|
||||
parallel_reads_multiplier=parallel_reads_multiplier,
|
||||
dummy_label=False
|
||||
dummy_label=False,
|
||||
pretrain=pretrain
|
||||
), filepaths[0:filepaths_count], filepaths_count
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue