mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
get python bridge working t convert .jsonl.gz → .tfrecord.gz
This commit is contained in:
parent
28a3f578d5
commit
a02c3436ab
1 changed files with 62 additions and 8 deletions
70
rainfallwrangler/src/lib/python/json2tfrecord.py
Normal file → Executable file
70
rainfallwrangler/src/lib/python/json2tfrecord.py
Normal file → Executable file
|
@ -1,34 +1,82 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
import gzip
|
||||
import json
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# TO PARSE:
|
||||
@tf.function
|
||||
def parse_item(item):
|
||||
parsed = tf.io.parse_single_example(item, features={
|
||||
"rainfallradar": tf.io.FixedLenFeature([], tf.string),
|
||||
"waterdepth": tf.io.FixedLenFeature([], tf.string)
|
||||
})
|
||||
rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32)
|
||||
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
|
||||
|
||||
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
||||
|
||||
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
||||
return rainfall, water
|
||||
|
||||
def parse_example(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5):
|
||||
return tf.data.TFRecordDataset(filenames,
|
||||
compression_type=compression_type,
|
||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||
).map(parse_item, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Convert a generated .jsonl.gz file to a .tfrecord.gz file")
|
||||
parser.add_argument("--input", "-i", help="Path to the input file to convert.", required=True)
|
||||
parser.add_argument("--output", "-o", help="Path to the output file to write to.", required=True)
|
||||
|
||||
return parser.parse_args(args=sys.argv[2:])
|
||||
return parser.parse_args(args=sys.argv[1:])
|
||||
|
||||
def convert(filepath_in, filepath_out):
|
||||
with gzip.open(filepath_in, "r") as handle, tf.io.TFRecordWriter(filepath_out) as writer:
|
||||
options = tf.io.TFRecordOptions(compression_type="GZIP", compression_level=9)
|
||||
with gzip.open(filepath_in, "r") as handle, tf.io.TFRecordWriter(filepath_out, options=options) as writer:
|
||||
i = -1
|
||||
for line in handle:
|
||||
i += 1
|
||||
if len(line) == 0:
|
||||
continue
|
||||
|
||||
###
|
||||
## 1: Parse JSON
|
||||
###
|
||||
obj = json.loads(line)
|
||||
|
||||
rainfall = tf.constant(obj.rainfallradar, dtype=tf.float32)
|
||||
water = tf.constant(obj.waterdepth, dtype=tf.float32)
|
||||
###
|
||||
## 2: Convert to tensor
|
||||
###
|
||||
rainfall = tf.constant(obj["rainfallradar"], dtype=tf.float32)
|
||||
water = tf.constant(obj["waterdepth"], dtype=tf.float32)
|
||||
|
||||
###
|
||||
## 3: Print shape definitions (required when parsing)
|
||||
###
|
||||
if i == 0:
|
||||
print("SHAPES\t"+json.dumps({ "rainfallradar": rainfall.shape.as_list(), "waterdepth": water.shape.as_list() }))
|
||||
|
||||
###
|
||||
## 4: Serialise tensors
|
||||
###
|
||||
rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(rainfall, name="rainfall").numpy()])
|
||||
water = tf.train.BytesList(value=[tf.io.serialize_tensor(water, name="water").numpy()])
|
||||
|
||||
###
|
||||
## 5: Write to .tfrecord.gz file
|
||||
###
|
||||
record = tf.train.Example(features=tf.train.Features(feature={
|
||||
"rainfallradar": tf.train.BytesList(bytes_list=tf.io.serialize_tensor(rainfall)),
|
||||
"waterdepth": tf.train.BytesList(bytes_list=tf.io.serialize_tensor(water))
|
||||
"rainfallradar": tf.train.Feature(bytes_list=rainfall),
|
||||
"waterdepth": tf.train.Feature(bytes_list=water)
|
||||
}))
|
||||
writer.write(record.SerializeToString())
|
||||
|
||||
print(i)
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -39,4 +87,10 @@ def main():
|
|||
sys.exit(2)
|
||||
|
||||
|
||||
convert(args.input, args.output)
|
||||
convert(args.input, args.output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
else:
|
||||
print("This script must be run directly. It cannot be imported.")
|
||||
exit(1)
|
||||
|
|
Loading…
Reference in a new issue