diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py
index d5ca3c6..daa756a 100644
--- a/aimodel/src/lib/dataset/dataset.py
+++ b/aimodel/src/lib/dataset/dataset.py
@@ -7,7 +7,7 @@
 from loguru import logger
 import tensorflow as tf
 
-from .shuffle import shuffle
+from shuffle import shuffle
 
 def parse_line(line):
 	if tf.strings.length(line) <= 0:
@@ -30,19 +30,19 @@ def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_
 		filenames=tf.data.Dataset.from_tensor_slices(filepaths).shuffle(len(filepaths), reshuffle_each_iteration=True),
 		compression_type=tf.constant("GZIP"),
 		num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) # iowait can cause issues - especially on Viper
+	# TODO: Get rid of this tf.py_function call somehow, because it acquires the Python Global Interpreter lock, which prevents more than 1 thread to run at a time, and .map() uses threads....
 	).map(tf.py_function(parse_line), num_parallel_calls=tf.data.AUTOTUNE) \
 		.filter(lambda item : item is not None) \
-		.shuffle(shuffle_buffer_size) \
+		.shuffle(1) \
 		.batch(batch_size) \
 		.prefetch(tf.data.AUTOTUNE)
 
 def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
 	filepaths = shuffle(list(filter(
-		lambda filepath: filepath.endswith(".jsonl.gz"),
-		os.scandir(dirpath_input)
+		lambda filepath: str(filepath).endswith(".jsonl.gz"),
+		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
 	)))
-	
 	filepaths_count = len(filepaths)
 	dataset_splitpoint = math.floor(filepaths_count * train_percentage)
@@ -56,7 +56,7 @@
 
 
 if __name__ == "__main__":
-	ds_train, ds_validate = dataset()
+	ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
 	for thing in ds_validate():
 		as_str = str(thing)
-		print(thing[:100])
\ No newline at end of file
+		print(thing[:200])
\ No newline at end of file