mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
train_predict: revamp jsonl handling
This commit is contained in:
parent
8195318a42
commit
587c1dfafa
1 changed files with 49 additions and 19 deletions
|
@ -31,6 +31,7 @@ def parse_args():
|
|||
parser.add_argument("--params", "-p", help="Optional. The file containing the model hyperparameters (usually called 'params.json'). If not specified, it's location will be determined automatically.")
|
||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 0. When using this start with 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data. SETTING THIS WILL SCRAMBLE THE ORDER OF THE DATASET.")
|
||||
parser.add_argument("--model-code", help="A description of the model used to predict the data. Will be inserted in the title of png plots.")
|
||||
parser.add_argument("--log", help="Optional. If specified when the file extension is .jsonl[.gz], then this chooses what is logged. Specify a comma separated list of values. Possible values: rainfall_actual, water_actual, water_predict. Default: rainfall_actual,water_actual,water_predict.")
|
||||
return parser
|
||||
|
||||
def run(args):
|
||||
|
@ -47,6 +48,10 @@ def run(args):
|
|||
args.output = "-"
|
||||
if (not hasattr(args, "model_code")) or args.model_code == None:
|
||||
args.model_code = ""
|
||||
if (not hasattr(args, "log")) or args.log == None:
|
||||
args.log = "rainfall_actual,water_actual,water_predict"
|
||||
|
||||
args.log = args.log.strip().split(",")
|
||||
|
||||
if not os.path.exists(args.params):
|
||||
raise Exception(f"Error: The specified filepath params.json hyperparameters ('{args.params}) does not exist.")
|
||||
|
@ -59,7 +64,8 @@ def run(args):
|
|||
os.mkdir(dirpath_output)
|
||||
|
||||
|
||||
ai = RainfallWaterSegmenter.from_checkpoint(args.checkpoint, **json.loads(readfile(args.params)))
|
||||
model_params = json.loads(readfile(args.params))
|
||||
ai = RainfallWaterSegmenter.from_checkpoint(args.checkpoint, **model_params)
|
||||
|
||||
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||
|
||||
|
@ -81,32 +87,30 @@ def run(args):
|
|||
logger.info(f"Records per file: {args.records_per_file}")
|
||||
|
||||
if output_mode == MODE_JSONL:
|
||||
do_jsonl(args, ai, dataset)
|
||||
do_jsonl(args, ai, dataset, args.model_code, model_params)
|
||||
else:
|
||||
do_png(args, ai, dataset, args.model_code)
|
||||
do_png(args, ai, dataset, args.model_code, model_params)
|
||||
|
||||
sys.stderr.write(">>> Complete\n")
|
||||
|
||||
def do_png(args, ai, dataset, model_code):
|
||||
def do_png(args, ai, dataset, model_code, model_params):
|
||||
if not os.path.exists(os.path.dirname(args.output)):
|
||||
os.mkdir(os.path.dirname(args.output))
|
||||
|
||||
model_params = json.loads(readfile(args.params))
|
||||
|
||||
i = 0
|
||||
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
|
||||
for item in gen:
|
||||
rainfall, water = item
|
||||
|
||||
water_predict_batch = ai.embed(rainfall)
|
||||
water = tf.squeeze(tf.unstack(water, axis=0))
|
||||
water = tf.unstack(water, axis=0)
|
||||
|
||||
i_batch = 0
|
||||
for water_predict in water_predict_batch:
|
||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
# [ width, height ]
|
||||
water_actual = water[i_batch]
|
||||
water_actual = tf.squeeze(water[i_batch])
|
||||
|
||||
segmentation_plot(
|
||||
water_actual, water_predict,
|
||||
|
@ -120,7 +124,7 @@ def do_png(args, ai, dataset, model_code):
|
|||
if i % 100 == 0:
|
||||
sys.stderr.write(f"Processed {i} items\r")
|
||||
|
||||
def do_jsonl(args, ai, dataset):
|
||||
def do_jsonl(args, ai, dataset, model_params):
|
||||
write_mode = "wt" if args.output.endswith(".gz") else "w"
|
||||
|
||||
handle = sys.stdout
|
||||
|
@ -138,18 +142,44 @@ def do_jsonl(args, ai, dataset):
|
|||
i = 0
|
||||
i_file = i
|
||||
files_done = 0
|
||||
for step_rainfall, step_water in ai.embed(dataset):
|
||||
if args.records_per_file > 0 and i_file > args.records_per_file:
|
||||
files_done += 1
|
||||
i_file = 0
|
||||
handle.close()
|
||||
logger.info(f"PROGRESS:file {files_done}")
|
||||
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode)
|
||||
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]):
|
||||
rainfall_actual_batch, water_actual_batch = batch
|
||||
|
||||
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||
water_predict_batch = ai.embed(rainfall_actual_batch)
|
||||
water_actual_batch = tf.unstack(water_actual_batch, axis=0)
|
||||
rainfall_actual_batch = tf.unstack(rainfall_actual_batch, axis=0)
|
||||
|
||||
i_batch = 0
|
||||
for water_predict in water_predict_batch:
|
||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
# [ width, height ]
|
||||
water_actual = tf.squeeze(water_actual_batch[i_batch])
|
||||
|
||||
if args.records_per_file > 0 and i_file > args.records_per_file:
|
||||
files_done += 1
|
||||
i_file = 0
|
||||
handle.close()
|
||||
logger.info(f"PROGRESS:file {files_done}")
|
||||
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode)
|
||||
|
||||
item_obj = {}
|
||||
if "rainfall_actual" in args.log:
|
||||
item_obj["rainfall_actual"] = rainfall_actual_batch[i_batch].numpy().list()
|
||||
if "water_actual" in args.log:
|
||||
item_obj["water_actual"] = water_actual.numpy().list()
|
||||
if "water_predict" in args.log:
|
||||
item_obj["water_predict"] = water_predict.numpy().list()
|
||||
|
||||
handle.write(json.dumps(item_obj, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||
|
||||
if i == 0 or i % 100 == 0:
|
||||
sys.stderr.write(f"[pretrain:predict] STEP {i}\r")
|
||||
|
||||
|
||||
i_batch += 1
|
||||
|
||||
|
||||
if i == 0 or i % 100 == 0:
|
||||
sys.stderr.write(f"[pretrain:predict] STEP {i}\r")
|
||||
|
||||
i += 1
|
||||
i_file += 1
|
||||
|
|
Loading…
Reference in a new issue