From 323d708692c35cb7fd30d62c78ac51ee1c9c0a2e Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Tue, 26 Jul 2022 19:53:18 +0100 Subject: [PATCH] dataset: add todo just why, Tensorflow?! tf.data.TextLineDataset looks almost too good to be true..... and it is, as despite supporting decompressing via gzip(!) it doesn't look like we can convince it to parse JSON :-/ --- aimodel/src/lib/dataset/dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py index d5ca3c6..daa756a 100644 --- a/aimodel/src/lib/dataset/dataset.py +++ b/aimodel/src/lib/dataset/dataset.py @@ -7,7 +7,7 @@ from loguru import logger import tensorflow as tf -from .shuffle import shuffle +from shuffle import shuffle def parse_line(line): if tf.strings.length(line) <= 0: @@ -30,19 +30,19 @@ def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_ filenames=tf.data.Dataset.from_tensor_slices(filepaths).shuffle(len(filepaths), reshuffle_each_iteration=True), compression_type=tf.constant("GZIP"), num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) # iowait can cause issues - especially on Viper + # TODO: Get rid of this tf.py_function call somehow, because it acquires the Python Global Interpreter lock, which prevents more than 1 thread to run at a time, and .map() uses threads.... ).map(tf.py_function(parse_line), num_parallel_calls=tf.data.AUTOTUNE) \ .filter(lambda item : item is not None) \ - .shuffle(shuffle_buffer_size) \ + .shuffle(1) \ .batch(batch_size) \ .prefetch(tf.data.AUTOTUNE) def dataset(dirpath_input, batch_size=64, train_percentage=0.8): filepaths = shuffle(list(filter( - lambda filepath: filepath.endswith(".jsonl.gz"), - os.scandir(dirpath_input) + lambda filepath: str(filepath).endswith(".jsonl.gz"), + [ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath ))) - filepaths_count = len(filepaths) dataset_splitpoint = math.floor(filepaths_count * train_percentage) @@ -56,7 +56,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8): if __name__ == "__main__": - ds_train, ds_validate = dataset() + ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/") for thing in ds_validate(): as_str = str(thing) - print(thing[:100]) \ No newline at end of file + print(thing[:200]) \ No newline at end of file