dataset: add todo

just why, Tensorflow?!
tf.data.TextLineDataset looks almost too good to be true..... and it is, as despite supporting decompressing via gzip(!) it doesn't look like we can convince it to parse JSON :-/
This commit is contained in:
Starbeamrainbowlabs 2022-07-26 19:53:18 +01:00
parent b53c77a2cb
commit 323d708692
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -7,7 +7,7 @@ from loguru import logger
import tensorflow as tf
from .shuffle import shuffle
from shuffle import shuffle
def parse_line(line):
if tf.strings.length(line) <= 0:
@ -30,19 +30,19 @@ def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_
filenames=tf.data.Dataset.from_tensor_slices(filepaths).shuffle(len(filepaths), reshuffle_each_iteration=True),
compression_type=tf.constant("GZIP"),
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) # iowait can cause issues - especially on Viper
# TODO: Get rid of this tf.py_function call somehow, because it acquires the Python Global Interpreter lock, which prevents more than 1 thread to run at a time, and .map() uses threads....
).map(tf.py_function(parse_line), num_parallel_calls=tf.data.AUTOTUNE) \
.filter(lambda item : item is not None) \
.shuffle(shuffle_buffer_size) \
.shuffle(1) \
.batch(batch_size) \
.prefetch(tf.data.AUTOTUNE)
def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
filepaths = shuffle(list(filter(
lambda filepath: filepath.endswith(".jsonl.gz"),
os.scandir(dirpath_input)
lambda filepath: str(filepath).endswith(".jsonl.gz"),
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
)))
filepaths_count = len(filepaths)
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
@ -56,7 +56,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
if __name__ == "__main__":
ds_train, ds_validate = dataset()
ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
for thing in ds_validate():
as_str = str(thing)
print(thing[:100])
print(thing[:200])