train_predict: revamp jsonl handling

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 16:53:08 +01:00
parent 8195318a42
commit 587c1dfafa
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -31,6 +31,7 @@ def parse_args():
parser.add_argument("--params", "-p", help="Optional. The file containing the model hyperparameters (usually called 'params.json'). If not specified, it's location will be determined automatically.")
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 0. When using this start with 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data. SETTING THIS WILL SCRAMBLE THE ORDER OF THE DATASET.")
parser.add_argument("--model-code", help="A description of the model used to predict the data. Will be inserted in the title of png plots.")
parser.add_argument("--log", help="Optional. If specified when the file extension is .jsonl[.gz], then this chooses what is logged. Specify a comma separated list of values. Possible values: rainfall_actual, water_actual, water_predict. Default: rainfall_actual,water_actual,water_predict.")
return parser
def run(args):
@ -47,6 +48,10 @@ def run(args):
args.output = "-"
if (not hasattr(args, "model_code")) or args.model_code == None:
args.model_code = ""
if (not hasattr(args, "log")) or args.log == None:
args.log = "rainfall_actual,water_actual,water_predict"
args.log = args.log.strip().split(",")
if not os.path.exists(args.params):
raise Exception(f"Error: The specified filepath params.json hyperparameters ('{args.params}) does not exist.")
@ -59,7 +64,8 @@ def run(args):
os.mkdir(dirpath_output)
ai = RainfallWaterSegmenter.from_checkpoint(args.checkpoint, **json.loads(readfile(args.params)))
model_params = json.loads(readfile(args.params))
ai = RainfallWaterSegmenter.from_checkpoint(args.checkpoint, **model_params)
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
@ -81,32 +87,30 @@ def run(args):
logger.info(f"Records per file: {args.records_per_file}")
if output_mode == MODE_JSONL:
do_jsonl(args, ai, dataset)
do_jsonl(args, ai, dataset, args.model_code, model_params)
else:
do_png(args, ai, dataset, args.model_code)
do_png(args, ai, dataset, args.model_code, model_params)
sys.stderr.write(">>> Complete\n")
def do_png(args, ai, dataset, model_code):
def do_png(args, ai, dataset, model_code, model_params):
if not os.path.exists(os.path.dirname(args.output)):
os.mkdir(os.path.dirname(args.output))
model_params = json.loads(readfile(args.params))
i = 0
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
for item in gen:
rainfall, water = item
water_predict_batch = ai.embed(rainfall)
water = tf.squeeze(tf.unstack(water, axis=0))
water = tf.unstack(water, axis=0)
i_batch = 0
for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ]
water_actual = water[i_batch]
water_actual = tf.squeeze(water[i_batch])
segmentation_plot(
water_actual, water_predict,
@ -120,7 +124,7 @@ def do_png(args, ai, dataset, model_code):
if i % 100 == 0:
sys.stderr.write(f"Processed {i} items\r")
def do_jsonl(args, ai, dataset):
def do_jsonl(args, ai, dataset, model_params):
write_mode = "wt" if args.output.endswith(".gz") else "w"
handle = sys.stdout
@ -138,18 +142,44 @@ def do_jsonl(args, ai, dataset):
i = 0
i_file = i
files_done = 0
for step_rainfall, step_water in ai.embed(dataset):
if args.records_per_file > 0 and i_file > args.records_per_file:
files_done += 1
i_file = 0
handle.close()
logger.info(f"PROGRESS:file {files_done}")
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode)
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]):
rainfall_actual_batch, water_actual_batch = batch
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
water_predict_batch = ai.embed(rainfall_actual_batch)
water_actual_batch = tf.unstack(water_actual_batch, axis=0)
rainfall_actual_batch = tf.unstack(rainfall_actual_batch, axis=0)
i_batch = 0
for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ]
water_actual = tf.squeeze(water_actual_batch[i_batch])
if args.records_per_file > 0 and i_file > args.records_per_file:
files_done += 1
i_file = 0
handle.close()
logger.info(f"PROGRESS:file {files_done}")
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode)
item_obj = {}
if "rainfall_actual" in args.log:
item_obj["rainfall_actual"] = rainfall_actual_batch[i_batch].numpy().list()
if "water_actual" in args.log:
item_obj["water_actual"] = water_actual.numpy().list()
if "water_predict" in args.log:
item_obj["water_predict"] = water_predict.numpy().list()
handle.write(json.dumps(item_obj, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
if i == 0 or i % 100 == 0:
sys.stderr.write(f"[pretrain:predict] STEP {i}\r")
i_batch += 1
if i == 0 or i % 100 == 0:
sys.stderr.write(f"[pretrain:predict] STEP {i}\r")
i += 1
i_file += 1