From cc6679c6095305b7fb70fa5e87c86bc8c4727f8f Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 20 Oct 2022 15:22:29 +0100 Subject: [PATCH] batch data; use generator --- aimodel/src/subcommands/train_predict.py | 39 +++++++++++++----------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/aimodel/src/subcommands/train_predict.py b/aimodel/src/subcommands/train_predict.py index 71f17b6..037441d 100644 --- a/aimodel/src/subcommands/train_predict.py +++ b/aimodel/src/subcommands/train_predict.py @@ -7,6 +7,7 @@ import re from loguru import logger import tensorflow as tf +from lib.dataset.batched_iterator import batched_iterator from lib.vis.segmentation_plot import segmentation_plot from lib.io.handle_open import handle_open @@ -87,25 +88,27 @@ def run(args): sys.stderr.write(">>> Complete\n") def do_png(args, ai, dataset, model_code): + model_params = json.loads(readfile(args.params)) + i = 0 - for rainfall, water in dataset: - water_predict = ai.embed(rainfall) - - # [ width, height, softmax_probabilities ] → [ batch, width, height ] - water_predict = tf.math.argmax(water_predict, axis=-1) - # [ width, height ] - water = tf.squeeze(water) - - segmentation_plot( - water, water_predict, - model_code, - args.output.replace("+d", str(i)) - ) - - i += 1 - - if i % 100 == 0: - sys.stderr.write(f"Processed {i} items") + for rainfall, water in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]): + water_predict_batch = ai.embed(rainfall) + for water_predict in water_predict_batch: + # [ width, height, softmax_probabilities ] → [ batch, width, height ] + water_predict = tf.math.argmax(water_predict, axis=-1) + # [ width, height ] + water = tf.squeeze(water) + + segmentation_plot( + water, water_predict, + model_code, + args.output.replace("+d", str(i)) + ) + + i += 1 + + if i % 100 == 0: + sys.stderr.write(f"Processed {i} items") def do_jsonl(args, ai, dataset): write_mode = "wt" if args.output.endswith(".gz") else "w"