mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
get python bridge working to convert .jsonl.gz → .tfrecord.gz
This commit is contained in:
parent
28a3f578d5
commit
a02c3436ab
1 changed files with 62 additions and 8 deletions
70
rainfallwrangler/src/lib/python/json2tfrecord.py
Normal file → Executable file
70
rainfallwrangler/src/lib/python/json2tfrecord.py
Normal file → Executable file
|
@ -1,34 +1,82 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
# TO PARSE:
|
||||||
|
@tf.function
def parse_item(item):
    """Decode one serialised tf.train.Example into a pair of float32 tensors.

    The record is expected to carry two scalar string features,
    "rainfallradar" and "waterdepth", each holding a serialised tensor.

    Args:
        item: A single serialised example, as yielded by a TFRecordDataset.

    Returns:
        A (rainfall, water) tuple of tensors parsed with out_type=tf.float32.
    """
    feature_spec = {
        "rainfallradar": tf.io.FixedLenFeature([], tf.string),
        "waterdepth": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(item, features=feature_spec)

    # TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
    # TODO: Any other additional parsing here, since multiple .map() calls are not optimal
    return (
        tf.io.parse_tensor(example["rainfallradar"], out_type=tf.float32),
        tf.io.parse_tensor(example["waterdepth"], out_type=tf.float32),
    )
|
||||||
|
|
||||||
|
def parse_example(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5):
    """Build a dataset of parsed (rainfall, water) pairs from TFRecord files.

    Args:
        filenames: Passed unchanged to tf.data.TFRecordDataset.
        compression_type: Compression of the record files (default "GZIP").
        parallel_reads_multiplier: Factor applied to the CPU count to choose
            how many files are read in parallel (rounded up).

    Returns:
        A tf.data.Dataset in which every record has been mapped through
        parse_item.
    """
    # os.cpu_count() returns None when the count cannot be determined, which
    # would make the multiplication below raise TypeError.
    cpu_count = os.cpu_count() or 1
    # Never request fewer than one parallel reader.
    parallel_reads = max(1, math.ceil(cpu_count * parallel_reads_multiplier))

    dataset = tf.data.TFRecordDataset(
        filenames,
        compression_type=compression_type,
        num_parallel_reads=parallel_reads,
    )
    return dataset.map(parse_item, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
|
|
||||||
def parse_args(argv=None):
    """Parse the command-line arguments for the converter.

    Args:
        argv: Optional list of argument strings. When None (the default),
            sys.argv[1:] is used, i.e. everything after the script name.

    Returns:
        An argparse.Namespace with the required `input` and `output` paths.
    """
    parser = argparse.ArgumentParser(description="Convert a generated .jsonl.gz file to a .tfrecord.gz file")
    parser.add_argument("--input", "-i", help="Path to the input file to convert.", required=True)
    parser.add_argument("--output", "-o", help="Path to the output file to write to.", required=True)

    # Slice from index 1: index 0 is the script name. Slicing from 2 would
    # silently drop the first real argument.
    return parser.parse_args(args=sys.argv[1:] if argv is None else argv)
|
|
||||||
|
|
||||||
def convert(filepath_in, filepath_out):
    """Convert a gzipped JSON-lines file into a gzipped TFRecord file.

    Each non-blank input line is parsed as a JSON object whose
    "rainfallradar" and "waterdepth" values are turned into float32 tensors,
    serialised, and written as one tf.train.Example.

    Two kinds of lines are printed to standard output:
      * "SHAPES\t<json>" once, for the first record converted, giving the
        shape of both tensors (required when parsing the records later).
      * the 0-based index of the input line after each record is written.

    Args:
        filepath_in: Path to the .jsonl.gz file to read.
        filepath_out: Path to the .tfrecord.gz file to write.
    """
    options = tf.io.TFRecordOptions(compression_type="GZIP", compression_level=9)
    shapes_printed = False

    with gzip.open(filepath_in, "r") as handle, tf.io.TFRecordWriter(filepath_out, options=options) as writer:
        for i, line in enumerate(handle):
            # Lines read from a file keep their trailing newline, so a plain
            # length check never matches a blank line; strip before testing.
            if not line.strip():
                continue

            ###
            ## 1: Parse JSON
            ###
            obj = json.loads(line)

            ###
            ## 2: Convert to tensor
            ###
            rainfall = tf.constant(obj["rainfallradar"], dtype=tf.float32)
            water = tf.constant(obj["waterdepth"], dtype=tf.float32)

            ###
            ## 3: Print shape definitions (required when parsing)
            ###
            # Keyed on the first *converted* record rather than on line 0, so
            # a leading blank line cannot suppress the shape announcement.
            if not shapes_printed:
                print("SHAPES\t"+json.dumps({ "rainfallradar": rainfall.shape.as_list(), "waterdepth": water.shape.as_list() }))
                shapes_printed = True

            ###
            ## 4: Serialise tensors
            ###
            rainfall_bytes = tf.train.BytesList(value=[tf.io.serialize_tensor(rainfall, name="rainfall").numpy()])
            water_bytes = tf.train.BytesList(value=[tf.io.serialize_tensor(water, name="water").numpy()])

            ###
            ## 5: Write to .tfrecord.gz file
            ###
            record = tf.train.Example(features=tf.train.Features(feature={
                "rainfallradar": tf.train.Feature(bytes_list=rainfall_bytes),
                "waterdepth": tf.train.Feature(bytes_list=water_bytes)
            }))
            writer.write(record.SerializeToString())

            # NOTE(review): indentation was lost in the view this was taken
            # from; the progress print is assumed to be per record - confirm.
            print(i)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -39,4 +87,10 @@ def main():
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
|
||||||
|
|
||||||
convert(args.input, args.output)
|
convert(args.input, args.output)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    main()
else:
    # This file is a command-line tool only, so importing it is refused.
    # sys.exit() is used instead of the bare exit() helper, which is injected
    # by the site module and is not guaranteed to exist (e.g. under python -S).
    print("This script must be run directly. It cannot be imported.", file=sys.stderr)
    sys.exit(1)
|
||||||
|
|
Loading…
Reference in a new issue