train_predict: fixup

This commit is contained in:
Starbeamrainbowlabs 2022-10-20 15:42:33 +01:00
parent cc6679c609
commit aed2348a95
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -91,7 +91,8 @@ def do_png(args, ai, dataset, model_code):
model_params = json.loads(readfile(args.params))
i = 0
for rainfall, water in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]):
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
for rainfall, water in gen:
water_predict_batch = ai.embed(rainfall)
for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ]