mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
Move dataset parsing function to the right place
This commit is contained in:
parent
50f214450f
commit
b52c7f89a7
2 changed files with 27 additions and 50 deletions
|
@ -9,38 +9,35 @@ import tensorflow as tf
|
||||||
|
|
||||||
from shuffle import shuffle
|
from shuffle import shuffle
|
||||||
|
|
||||||
def parse_line(line):
|
|
||||||
if tf.strings.length(line) <= 0:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
# Yes, this is really what the function is called that converts a string tensor to a regular python string.....
|
|
||||||
obj = json.loads(line.numpy())
|
|
||||||
except:
|
|
||||||
logger.warn("Ignoring invalid line.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
rainfall = tf.constant(obj.rainfallradar, dtype=tf.float32)
|
# TO PARSE:
|
||||||
waterdepth = tf.constant(obj.waterdepth, dtype=tf.float32)
|
@tf.function
|
||||||
|
def parse_item(item):
|
||||||
|
parsed = tf.io.parse_single_example(item, features={
|
||||||
|
"rainfallradar": tf.io.FixedLenFeature([], tf.string),
|
||||||
|
"waterdepth": tf.io.FixedLenFeature([], tf.string)
|
||||||
|
})
|
||||||
|
rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32)
|
||||||
|
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
|
||||||
|
|
||||||
# Inputs, dummy label since we'll be using semi-supervised contrastive learning
|
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
||||||
return rainfall, waterdepth
|
|
||||||
|
|
||||||
def make_dataset(filepaths, batch_size, shuffle_buffer_size=128, parallel_reads_multiplier=2):
|
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
||||||
return tf.data.TextLineDataset(
|
return rainfall, water
|
||||||
filenames=tf.data.Dataset.from_tensor_slices(filepaths).shuffle(len(filepaths), reshuffle_each_iteration=True),
|
|
||||||
compression_type=tf.constant("GZIP"),
|
def make_dataset(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier) # iowait can cause issues - especially on Viper
|
return tf.data.TFRecordDataset(filenames,
|
||||||
# TODO: Get rid of this tf.py_function call somehow, because it acquires the Python Global Interpreter lock, which prevents more than 1 thread to run at a time, and .map() uses threads....
|
compression_type=compression_type,
|
||||||
).map(tf.py_function(parse_line), num_parallel_calls=tf.data.AUTOTUNE) \
|
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||||
.filter(lambda item : item is not None) \
|
).shuffle(shuffle_buffer_size) \
|
||||||
.shuffle(1) \
|
.map(parse_item, num_parallel_calls=tf.data.AUTOTUNE) \
|
||||||
.batch(batch_size) \
|
.batch(batch_size) \
|
||||||
.prefetch(tf.data.AUTOTUNE)
|
.prefetch(tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
|
||||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
|
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||||
filepaths = shuffle(list(filter(
|
filepaths = shuffle(list(filter(
|
||||||
lambda filepath: str(filepath).endswith(".jsonl.gz"),
|
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||||
)))
|
)))
|
||||||
filepaths_count = len(filepaths)
|
filepaths_count = len(filepaths)
|
||||||
|
@ -49,8 +46,9 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8):
|
||||||
filepaths_train = filepaths[:dataset_splitpoint]
|
filepaths_train = filepaths[:dataset_splitpoint]
|
||||||
filepaths_validate = filepaths[dataset_splitpoint:]
|
filepaths_validate = filepaths[dataset_splitpoint:]
|
||||||
|
|
||||||
dataset_train = make_dataset(filepaths_train, batch_size)
|
dataset_train = make_dataset(filepaths_train, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||||
dataset_validate = make_dataset(filepaths_validate, batch_size)
|
dataset_validate = make_dataset(filepaths_validate, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||||
|
|
||||||
return dataset_train, dataset_validate
|
return dataset_train, dataset_validate
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,27 +11,6 @@ if not os.environ.get("NO_SILENCE"):
|
||||||
silence_tensorflow()
|
silence_tensorflow()
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
# TO PARSE:
|
|
||||||
@tf.function
|
|
||||||
def parse_item(item):
|
|
||||||
parsed = tf.io.parse_single_example(item, features={
|
|
||||||
"rainfallradar": tf.io.FixedLenFeature([], tf.string),
|
|
||||||
"waterdepth": tf.io.FixedLenFeature([], tf.string)
|
|
||||||
})
|
|
||||||
rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32)
|
|
||||||
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
|
|
||||||
|
|
||||||
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
|
||||||
|
|
||||||
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
|
||||||
return rainfall, water
|
|
||||||
|
|
||||||
def parse_example(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5):
|
|
||||||
return tf.data.TFRecordDataset(filenames,
|
|
||||||
compression_type=compression_type,
|
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
|
||||||
).map(parse_item, num_parallel_calls=tf.data.AUTOTUNE)
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Convert a generated .jsonl.gz file to a .tfrecord.gz file")
|
parser = argparse.ArgumentParser(description="Convert a generated .jsonl.gz file to a .tfrecord.gz file")
|
||||||
parser.add_argument("--input", "-i", help="Path to the input file to convert.", required=True)
|
parser.add_argument("--input", "-i", help="Path to the input file to convert.", required=True)
|
||||||
|
|
Loading…
Reference in a new issue