write params.json properly

This commit is contained in:
Starbeamrainbowlabs 2022-09-27 17:49:54 +01:00
parent a5455dc22a
commit dbfa45a016
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -8,8 +8,8 @@ import re
from loguru import logger
import tensorflow as tf
import numpy as np
from lib.io.writefile import writefile
from lib.io.writefile import writefile
from lib.io.handle_open import handle_open
from lib.ai.RainfallWaterContraster import RainfallWaterContraster
from lib.dataset.dataset import dataset_predict
@ -76,11 +76,14 @@ def run(args):
write_mode = "wb"
handle = sys.stdout
filepath_params = None
if filepath_output != "-":
handle = handle_open(
filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0),
write_mode
)
filepath_params = os.path.join(os.path.dirname(filepath_output), "/params.json")
i = 0
i_file = i
@ -96,8 +99,11 @@ def run(args):
if output_mode == MODE_JSONL:
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
elif output_mode == MODE_TFRECORD:
if i == 0:
writefile(json.dumps({ }))
if i == 0 and filepath_params is not None:
writefile(filepath_params, json.dumps({
"rainfallradar": step_rainfall.shape.as_list(),
"waterdepth": step_water.shape.as_list()
}))
step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()])
step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])