Mirror of https://github.com/sbrl/research-rainfallradar
dataset: add todo
Just why, TensorFlow?! tf.data.TextLineDataset looks almost too good to be true..... and it is: despite supporting decompression via gzip(!), it doesn't look like we can convince it to parse JSON :-/
parent: b53c77a2cb
commit: 323d708692

1 changed file with 7 additions and 7 deletions
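The complaint in the commit message boils down to this: tf.data.TextLineDataset will happily stream gzipped .jsonl.gz files line by line, but every element it yields is just a tf.string tensor, and TensorFlow has no general-purpose JSON-parsing op, so the decoding has to drop back into Python via tf.py_function. A minimal sketch of that shape (not the repository's actual code; the file name and the "rainfall" field are made up for illustration):

import json

import tensorflow as tf


def parse_line_py(line):
	# Runs eagerly inside tf.py_function, so .numpy() and json.loads are available here.
	obj = json.loads(line.numpy().decode("utf-8"))
	return tf.convert_to_tensor(obj["rainfall"], dtype=tf.float32)  # hypothetical field name


ds = tf.data.TextLineDataset(
	["example.jsonl.gz"],       # hypothetical input file
	compression_type="GZIP"     # the decompression part works fine...
).map(
	# ...but parsing the JSON needs tf.py_function, with inputs and output dtype given explicitly.
	lambda line: tf.py_function(parse_line_py, [line], tf.float32),
	num_parallel_calls=tf.data.AUTOTUNE
)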
@@ -7,7 +7,7 @@ from loguru import logger
 
 import tensorflow as tf
 
-from .shuffle import shuffle
+from shuffle import shuffle
 
 def parse_line(line):
 	if tf.strings.length(line) <= 0:
@@ -30,19 +30,19 @@ def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_
 		filenames=tf.data.Dataset.from_tensor_slices(filepaths).shuffle(len(filepaths), reshuffle_each_iteration=True),
 		compression_type=tf.constant("GZIP"),
 		num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) # iowait can cause issues - especially on Viper
-
+	# TODO: Get rid of this tf.py_function call somehow, because it acquires the Python Global Interpreter Lock, which prevents more than 1 thread from running at a time, and .map() uses threads....
 	).map(tf.py_function(parse_line), num_parallel_calls=tf.data.AUTOTUNE) \
 		.filter(lambda item : item is not None) \
-		.shuffle(shuffle_buffer_size) \
+		.shuffle(1) \
 		.batch(batch_size) \
 		.prefetch(tf.data.AUTOTUNE)
 
 
 def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
 	filepaths = shuffle(list(filter(
-		lambda filepath: filepath.endswith(".jsonl.gz"),
-		os.scandir(dirpath_input)
+		lambda filepath: str(filepath).endswith(".jsonl.gz"),
+		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
 	)))
 	
 	filepaths_count = len(filepaths)
 	dataset_splitpoint = math.floor(filepaths_count * train_percentage)
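The TODO added in this hunk is about the Global Interpreter Lock: tf.py_function runs its Python callable under the GIL, so num_parallel_calls on .map() gains very little. One common way around that, sketched below under the assumption that the .jsonl.gz records can be converted offline and that each record carries a flat list of floats (the "rainfall" field name and length 64 are placeholders, not the repository's real schema), is to pre-serialise everything to TFRecord once and then parse with pure TensorFlow ops:

import gzip
import json

import tensorflow as tf


def jsonl_gz_to_tfrecord(path_in, path_out, field="rainfall"):
	# One-off offline conversion: the JSON is parsed in ordinary Python here,
	# so the training pipeline never has to touch it.
	with gzip.open(path_in, "rt") as handle, tf.io.TFRecordWriter(path_out) as writer:
		for line in handle:
			if not line.strip():
				continue
			obj = json.loads(line)
			example = tf.train.Example(features=tf.train.Features(feature={
				field: tf.train.Feature(float_list=tf.train.FloatList(value=obj[field]))
			}))
			writer.write(example.SerializeToString())


def parse_example(serialised, field="rainfall", length=64):
	# Pure graph ops: no tf.py_function, no GIL, so parallel .map() calls actually run in parallel.
	parsed = tf.io.parse_single_example(serialised, {
		field: tf.io.FixedLenFeature([length], tf.float32)
	})
	return parsed[field]


ds = tf.data.TFRecordDataset(["example.tfrecord"]) \
	.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)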
@@ -56,7 +56,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
 
 
 if __name__ == "__main__":
-	ds_train, ds_validate = dataset()
+	ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
 	for thing in ds_validate():
 		as_str = str(thing)
-		print(thing[:100])
+		print(thing[:200])
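One further aside on the pipeline shown in the diff: .filter(lambda item : item is not None) is traced by tf.data, and since the traced item is always a symbolic tensor (never Python None), the predicate most likely collapses to a constant true and filters nothing. Given that parse_line already checks tf.strings.length(line) <= 0, empty lines could instead be dropped with pure string ops before the map, as in this sketch (the file name and the float32 output dtype of parse_line are assumptions):

import tensorflow as tf

filenames = ["example.jsonl.gz"]  # hypothetical
parsed = tf.data.TextLineDataset(filenames, compression_type="GZIP") \
	.filter(lambda line: tf.strings.length(line) > 0) \
	.map(
		lambda line: tf.py_function(parse_line, [line], tf.float32),  # parse_line as defined in the file above
		num_parallel_calls=tf.data.AUTOTUNE
	)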