write params.json properly

This commit is contained in:
Starbeamrainbowlabs 2022-09-27 17:49:54 +01:00
parent a5455dc22a
commit dbfa45a016
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -8,8 +8,8 @@ import re
from loguru import logger from loguru import logger
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from lib.io.writefile import writefile
from lib.io.writefile import writefile
from lib.io.handle_open import handle_open from lib.io.handle_open import handle_open
from lib.ai.RainfallWaterContraster import RainfallWaterContraster from lib.ai.RainfallWaterContraster import RainfallWaterContraster
from lib.dataset.dataset import dataset_predict from lib.dataset.dataset import dataset_predict
@ -76,11 +76,14 @@ def run(args):
write_mode = "wb" write_mode = "wb"
handle = sys.stdout handle = sys.stdout
filepath_params = None
if filepath_output != "-": if filepath_output != "-":
handle = handle_open( handle = handle_open(
filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0), filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0),
write_mode write_mode
) )
filepath_params = os.path.join(os.path.dirname(filepath_output), "/params.json")
i = 0 i = 0
i_file = i i_file = i
@ -96,8 +99,11 @@ def run(args):
if output_mode == MODE_JSONL: if output_mode == MODE_JSONL:
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
elif output_mode == MODE_TFRECORD: elif output_mode == MODE_TFRECORD:
if i == 0: if i == 0 and filepath_params is not None:
writefile(json.dumps({ })) writefile(filepath_params, json.dumps({
"rainfallradar": step_rainfall.shape.as_list(),
"waterdepth": step_water.shape.as_list()
}))
step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()]) step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()])
step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()]) step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])