mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
batch data; use generator
This commit is contained in:
parent
d306853c42
commit
cc6679c609
1 changed files with 21 additions and 18 deletions
|
@ -7,6 +7,7 @@ import re
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from lib.dataset.batched_iterator import batched_iterator
|
||||||
|
|
||||||
from lib.vis.segmentation_plot import segmentation_plot
|
from lib.vis.segmentation_plot import segmentation_plot
|
||||||
from lib.io.handle_open import handle_open
|
from lib.io.handle_open import handle_open
|
||||||
|
@ -87,25 +88,27 @@ def run(args):
|
||||||
sys.stderr.write(">>> Complete\n")
|
sys.stderr.write(">>> Complete\n")
|
||||||
|
|
||||||
def do_png(args, ai, dataset, model_code):
|
def do_png(args, ai, dataset, model_code):
|
||||||
|
model_params = json.loads(readfile(args.params))
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
for rainfall, water in dataset:
|
for rainfall, water in batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]):
|
||||||
water_predict = ai.embed(rainfall)
|
water_predict_batch = ai.embed(rainfall)
|
||||||
|
for water_predict in water_predict_batch:
|
||||||
|
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||||
|
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||||
|
# [ width, height ]
|
||||||
|
water = tf.squeeze(water)
|
||||||
|
|
||||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
segmentation_plot(
|
||||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
water, water_predict,
|
||||||
# [ width, height ]
|
model_code,
|
||||||
water = tf.squeeze(water)
|
args.output.replace("+d", str(i))
|
||||||
|
)
|
||||||
|
|
||||||
segmentation_plot(
|
i += 1
|
||||||
water, water_predict,
|
|
||||||
model_code,
|
|
||||||
args.output.replace("+d", str(i))
|
|
||||||
)
|
|
||||||
|
|
||||||
i += 1
|
if i % 100 == 0:
|
||||||
|
sys.stderr.write(f"Processed {i} items")
|
||||||
if i % 100 == 0:
|
|
||||||
sys.stderr.write(f"Processed {i} items")
|
|
||||||
|
|
||||||
def do_jsonl(args, ai, dataset):
|
def do_jsonl(args, ai, dataset):
|
||||||
write_mode = "wt" if args.output.endswith(".gz") else "w"
|
write_mode = "wt" if args.output.endswith(".gz") else "w"
|
||||||
|
|
Loading…
Reference in a new issue