batch data; use generator

This commit is contained in:
Starbeamrainbowlabs 2022-10-20 15:22:29 +01:00
parent d306853c42
commit cc6679c609
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -7,6 +7,7 @@ import re
from loguru import logger from loguru import logger
import tensorflow as tf import tensorflow as tf
from lib.dataset.batched_iterator import batched_iterator
from lib.vis.segmentation_plot import segmentation_plot from lib.vis.segmentation_plot import segmentation_plot
from lib.io.handle_open import handle_open from lib.io.handle_open import handle_open
@ -87,25 +88,27 @@ def run(args):
sys.stderr.write(">>> Complete\n") sys.stderr.write(">>> Complete\n")
def do_png(args, ai, dataset, model_code): def do_png(args, ai, dataset, model_code):
model_params = json.loads(readfile(args.params))
i = 0 i = 0
for rainfall, water in dataset: for rainfall, water in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]):
water_predict = ai.embed(rainfall) water_predict_batch = ai.embed(rainfall)
for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ] # [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1) water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ] # [ width, height ]
water = tf.squeeze(water) water = tf.squeeze(water)
segmentation_plot( segmentation_plot(
water, water_predict, water, water_predict,
model_code, model_code,
args.output.replace("+d", str(i)) args.output.replace("+d", str(i))
) )
i += 1 i += 1
if i % 100 == 0: if i % 100 == 0:
sys.stderr.write(f"Processed {i} items") sys.stderr.write(f"Processed {i} items")
def do_jsonl(args, ai, dataset): def do_jsonl(args, ai, dataset):
write_mode = "wt" if args.output.endswith(".gz") else "w" write_mode = "wt" if args.output.endswith(".gz") else "w"