From 18a7d3674b7906d1b5c5d3da1a2dcb8b62c5b09b Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Tue, 26 Jul 2022 19:14:10 +0100
Subject: [PATCH] ai: create (untested) dataset

---
NOTE(review): the contents of dataset.py, shuffle.py and requirements.txt were
revised during review, so the blob hashes on their "index" lines below are
stale and should be regenerated (e.g. by re-running git format-patch).

 aimodel/src/lib/__init__.py         |   0
 aimodel/src/lib/dataset/__init__.py |   0
 aimodel/src/lib/dataset/dataset.py  | 116 ++++++++++++++++++++++++++++
 aimodel/src/lib/dataset/shuffle.py  |  20 +++++
 aimodel/src/requirements.txt        |   2 +
 5 files changed, 138 insertions(+)
 create mode 100644 aimodel/src/lib/__init__.py
 create mode 100644 aimodel/src/lib/dataset/__init__.py
 create mode 100644 aimodel/src/lib/dataset/dataset.py
 create mode 100644 aimodel/src/lib/dataset/shuffle.py
 create mode 100644 aimodel/src/requirements.txt

diff --git a/aimodel/src/lib/__init__.py b/aimodel/src/lib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aimodel/src/lib/dataset/__init__.py b/aimodel/src/lib/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py
new file mode 100644
index 0000000..1bf3c32
--- /dev/null
+++ b/aimodel/src/lib/dataset/dataset.py
@@ -0,0 +1,116 @@
+import os
+import sys
+import math
+import json
+from socket import if_nameindex  # NOTE(review): looks unused here (stray auto-import?) - confirm before removing
+
+from loguru import logger
+
+import tensorflow as tf
+
+from .shuffle import shuffle
+
+
+def parse_line(line):
+    """
+    Parses one line of JSON text into a pair of float32 tensors.
+    @param line tf.Tensor A scalar eager string tensor holding one JSON object.
+    @return tuple|None (rainfallradar, waterdepth) as float32 tensors, or None if the line is empty or invalid.
+    """
+    if tf.strings.length(line) <= 0:
+        return None
+
+    try:
+        # Yes, this is really what the function is called that converts a string tensor to a regular python string.....
+        obj = json.loads(line.numpy())
+        # json.loads() returns plain dicts, so fields are read by key rather than by attribute
+        rainfall = tf.constant(obj["rainfallradar"], dtype=tf.float32)
+        waterdepth = tf.constant(obj["waterdepth"], dtype=tf.float32)
+    except (ValueError, KeyError, TypeError):
+        logger.warning("Ignoring invalid line.")
+        return None
+
+    # Inputs, dummy label since we'll be using semi-supervised contrastive learning
+    return rainfall, waterdepth
+
+
+def _parse_line_strict(line):
+    """
+    Wraps parse_line() for use inside tf.py_function.
+    tf.py_function must return tensors matching Tout so it can't return None - we raise instead.
+    @param line tf.Tensor A scalar eager string tensor holding one JSON object.
+    @return tuple (rainfallradar, waterdepth) as float32 tensors.
+    @raises ValueError If the line is empty or invalid.
+    """
+    result = parse_line(line)
+    if result is None:
+        raise ValueError("Unusable line.")
+    return result
+
+
+def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_multiplier=2):
+    """
+    Builds a shuffled, batched and prefetched tf.data pipeline from gzipped JSONL files.
+    @param filepaths list[str] The paths to the .jsonl.gz files to read.
+    @param batch_size int The number of items per batch.
+    @param shuffle_buffer_size int The size of the item shuffle buffer.
+    @param parallel_reads_multiplier float How many files to read in parallel, as a multiple of the CPU count.
+    @return tf.data.Dataset A dataset of batched (rainfallradar, waterdepth) pairs.
+    """
+    # os.cpu_count() returns None when the count can't be determined
+    cpu_count = os.cpu_count() or 1
+
+    # Explicit string dtype, since an empty list would otherwise become a float32 tensor
+    filenames = tf.data.Dataset.from_tensor_slices(
+        tf.constant(filepaths, dtype=tf.string)
+    ).shuffle(max(1, len(filepaths)), reshuffle_each_iteration=True)
+
+    lines = tf.data.TextLineDataset(
+        filenames=filenames,
+        compression_type=tf.constant("GZIP"),
+        num_parallel_reads=math.ceil(cpu_count * parallel_reads_multiplier) # iowait can cause issues - especially in Viper
+    )
+
+    items = lines.map(
+        # tuple(), because tf.data would try to pack a list into a single tensor
+        lambda line: tuple(tf.py_function(_parse_line_strict, inp=[line], Tout=[tf.float32, tf.float32])),
+        num_parallel_calls=tf.data.AUTOTUNE
+    )
+    # Drop the lines that _parse_line_strict() rejected
+    items = items.apply(tf.data.experimental.ignore_errors())
+
+    return items.shuffle(shuffle_buffer_size) \
+        .batch(batch_size) \
+        .prefetch(tf.data.AUTOTUNE)
+
+
+def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
+    """
+    Creates training and validation datasets from a directory of .jsonl.gz files.
+    @param dirpath_input str The directory containing the .jsonl.gz files.
+    @param batch_size int The number of items per batch.
+    @param train_percentage float The fraction (0 to 1) of the files to use for training.
+    @return tuple (dataset_train, dataset_validate)
+    """
+    # os.scandir() yields DirEntry objects rather than strings, so filter on .name and keep .path
+    with os.scandir(dirpath_input) as entries:
+        filepaths = shuffle([
+            entry.path for entry in entries
+            if entry.name.endswith(".jsonl.gz")
+        ])
+
+    filepaths_count = len(filepaths)
+    dataset_splitpoint = math.floor(filepaths_count * train_percentage)
+
+    filepaths_train = filepaths[:dataset_splitpoint]
+    filepaths_validate = filepaths[dataset_splitpoint:]
+
+    dataset_train = make_dataset(filepaths_train, batch_size)
+    dataset_validate = make_dataset(filepaths_validate, batch_size)
+    return dataset_train, dataset_validate
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2:
+        sys.exit("Usage: python -m lib.dataset.dataset <dirpath_input>")
+    dataset(sys.argv[1])
diff --git a/aimodel/src/lib/dataset/shuffle.py b/aimodel/src/lib/dataset/shuffle.py
new file mode 100644
index 0000000..487fe2c
--- /dev/null
+++ b/aimodel/src/lib/dataset/shuffle.py
@@ -0,0 +1,20 @@
+from copy import deepcopy
+from random import randint
+
+
+def shuffle(lst):
+    """
+    Shuffles a list with the Fisher-Yates algorithm.
+    @ref https://poopcode.com/shuffle-a-list-in-python-fisher-yates/
+    @param lst list The list to shuffle.
+    @return list A new list that is a shuffled copy of the original.
+    """
+    # Work on a deep copy so that the caller's list (and its items) is left untouched
+    tmplist = deepcopy(lst)
+    m = len(tmplist)
+    while m:
+        m -= 1
+        # randint() is inclusive at both ends, so an item may be swapped with itself
+        i = randint(0, m)
+        tmplist[m], tmplist[i] = tmplist[i], tmplist[m]
+    return tmplist
diff --git a/aimodel/src/requirements.txt b/aimodel/src/requirements.txt
new file mode 100644
index 0000000..8145e77
--- /dev/null
+++ b/aimodel/src/requirements.txt
@@ -0,0 +1,2 @@
+tensorflow>=2.4
+loguru