Mirror of https://github.com/sbrl/research-rainfallradar, synced 2024-11-22 09:13:01 +00:00
ai: implement batched_iterator to replace .batch()
...apparently when you iterate a .batch()-ed dataset in something like a tf.function, you get a BatchedDataset or whatever back instead of the actual tensors :-/
This commit is contained in:
parent ccd256c00a
commit bd64986332

3 changed files with 29 additions and 7 deletions
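At a glance, the commit swaps tf.data's built-in batching for a hand-rolled iterator that stacks items itself. The toy contrast below (made-up dataset and shapes, not code from this repository) shows the two approaches when iterated eagerly; the BatchedDataset-inside-a-tf.function behaviour the message complains about is not reproduced here.

# Toy contrast of the two batching approaches; dataset and shapes are illustrative only.
import tensorflow as tf

dataset = tf.data.Dataset.range(10).map(lambda i: tf.fill([3], i))

# Built-in batching: tf.data groups items into (4, 3) tensors.
for batch in dataset.batch(4, drop_remainder=True):
    print(batch.shape)  # (4, 3)

# Manual batching, the approach the new batched_iterator() helper takes:
# iterate the unbatched dataset eagerly and tf.stack() the items yourself.
items = []
for item in dataset:
    items.append(item)
    if len(items) >= 4:
        print(tf.stack(items).shape)  # (4, 3)
        items.clear()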
@@ -4,6 +4,8 @@ import json
 from loguru import logger
 import tensorflow as tf
 
+from ..dataset.batched_iterator import batched_iterator
+
 from ..io.find_paramsjson import find_paramsjson
 from ..io.readfile import readfile
 from ..io.writefile import writefile

@@ -86,9 +88,9 @@ class RainfallWaterContraster(object):
     
     def embed(self, dataset):
         i_batch = -1
-        for batch in dataset:
+        for batch in batched_iterator(dataset, batch_size=self.batch_size):
             i_batch += 1
-            rainfall = self.model_predict.predict_on_batch(batch[0]) # ((rainfall, water), dummy_label)
+            rainfall = self.model_predict.predict(batch[0]) # ((rainfall, water), dummy_label)
             
             for step in tf.unstack(rainfall, axis=0):
                 yield step
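Condensed, the new embed() flow looks roughly like the sketch below. Names are stand-ins, and passing tensors_in_item=2 is an assumption made to match the ((rainfall, water), dummy_label) comment; the diff itself calls batched_iterator() with only batch_size.

import tensorflow as tf

def embed(model_predict, dataset, batch_size):
    # Stack the unbatched dataset into batches ourselves rather than relying on .batch().
    for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=batch_size):
        rainfall = model_predict.predict(batch[0])  # one forward pass per batch of inputs
        for step in tf.unstack(rainfall, axis=0):   # split back into per-sample embeddings
            yield step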
aimodel/src/lib/dataset/batched_iterator.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+import tensorflow as tf
+
+
+def batched_iterator(dataset, tensors_in_item=1, batch_size=64):
+    acc = [ [] for _ in range(tensors_in_item) ]
+    i_item = 0
+    for item in dataset:
+        i_item += 1
+        
+        if tensors_in_item == 1:
+            item = [ item ]
+        
+        for i_tensor, tensor in item:
+            acc[i_tensor].append(tensor)
+        
+        if i_item >= batch_size:
+            yield [ tf.stack(tensors) for tensors in acc ]
+            for arr in acc:
+                arr.clear()
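As committed, the helper has a few rough edges: `for i_tensor, tensor in item:` iterates the item's tensors directly rather than with enumerate(), `i_item` is never reset after a yield (so after the first full batch every subsequent item would be flushed as a batch of one), and a trailing partial batch is silently dropped. The sketch below is one reading of what the function appears intended to do; the enumerate() call, the length-based flush check, and the partial-batch handling are assumptions, not part of the commit.

import tensorflow as tf

def batched_iterator(dataset, tensors_in_item=1, batch_size=64):
    # Eagerly group items from an unbatched tf.data.Dataset into lists of stacked
    # tensors, one list entry per tensor position within each item.
    acc = [[] for _ in range(tensors_in_item)]
    for item in dataset:
        if tensors_in_item == 1:
            item = [item]  # normalise single-tensor items to a 1-element sequence
        for i_tensor, tensor in enumerate(item):
            acc[i_tensor].append(tensor)
        if len(acc[0]) >= batch_size:
            yield [tf.stack(tensors) for tensors in acc]
            for arr in acc:
                arr.clear()  # start accumulating the next batch
    if acc[0]:  # emit the final partial batch instead of dropping it
        yield [tf.stack(tensors) for tensors in acc]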
@@ -54,9 +54,10 @@ def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression
         compression_type=compression_type,
         num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
     ).shuffle(shuffle_buffer_size) \
-        .map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE) \
-        .batch(batch_size, drop_remainder=True)
+        .map(parse_item(metadata, shape_water_desired=shape_watch_desired, dummy_label=dummy_label), num_parallel_calls=tf.data.AUTOTUNE)
+    
+    if batch_size != None:
+        dataset = dataset.batch(batch_size, drop_remainder=True)
     
     if prefetch:
         dataset = dataset.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
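The pattern introduced here is optional batching: only call .batch() when a batch_size is supplied, so callers such as dataset_predict() can pass batch_size=None and get individual items back for batched_iterator() to stack later. Below is a minimal self-contained sketch of that shape, using toy data and illustrative names rather than the repository's make_dataset().

import tensorflow as tf

def make_toy_dataset(batch_size=64, prefetch=True):
    # Stand-in pipeline: integer range data instead of rainfall/water tiles.
    dataset = tf.data.Dataset.range(100).map(lambda i: tf.fill([3], i))
    if batch_size is not None:  # the diff spells this check as `batch_size != None`
        dataset = dataset.batch(batch_size, drop_remainder=True)
    if prefetch:
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset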
@@ -84,7 +85,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
     
     return dataset_train, dataset_validate #, filepaths
 
-def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5, prefetch=False):
+def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True):
     filepaths = get_filepaths(dirpath_input)
     filepaths_count = len(filepaths)
     for i in range(len(filepaths)):

@@ -93,8 +94,8 @@ def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5,
     return make_dataset(
         filepaths=filepaths,
         metadata=read_metadata(dirpath_input),
-        batch_size=batch_size,
         parallel_reads_multiplier=parallel_reads_multiplier,
+        batch_size=None,
         dummy_label=False,
         prefetch=prefetch
     ), filepaths[0:filepaths_count], filepaths_count
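Putting the three files together: dataset_predict() now returns an unbatched dataset alongside the file paths and their count, and leaves batching to the consumer. A hedged sketch of the resulting call site; the directory path and batch size are placeholders, and the loop body stands in for RainfallWaterContraster.embed().

dataset, filepaths, filepaths_count = dataset_predict(
    dirpath_input="<input directory>",  # placeholder path
    parallel_reads_multiplier=1.5,
    prefetch=True,
)

# The dataset yields single items, so the consumer chooses the batch size,
# e.g. via the new batched_iterator() helper.
for batch in batched_iterator(dataset, batch_size=64):
    ...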