diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index 9577aa3..2a3694f 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -8,8 +8,8 @@ import re from loguru import logger import tensorflow as tf import numpy as np -from lib.io.writefile import writefile +from lib.io.writefile import writefile from lib.io.handle_open import handle_open from lib.ai.RainfallWaterContraster import RainfallWaterContraster from lib.dataset.dataset import dataset_predict @@ -76,11 +76,14 @@ def run(args): write_mode = "wb" handle = sys.stdout + filepath_params = None if filepath_output != "-": handle = handle_open( filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0), write_mode ) + filepath_params = os.path.join(os.path.dirname(filepath_output), "/params.json") + i = 0 i_file = i @@ -96,8 +99,11 @@ def run(args): if output_mode == MODE_JSONL: handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 elif output_mode == MODE_TFRECORD: - if i == 0: - writefile(json.dumps({ })) + if i == 0 and filepath_params is not None: + writefile(filepath_params, json.dumps({ + "rainfallradar": step_rainfall.shape.as_list(), + "waterdepth": step_water.shape.as_list() + })) step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()]) step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])