mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
pretrain_predict fix write mode
This commit is contained in:
parent
f95fd8f9e4
commit
d6ff3fb2ce
2 changed files with 10 additions and 3 deletions
|
@ -88,7 +88,7 @@ class RainfallWaterSegmenter(object):
|
||||||
i_batch = -1
|
i_batch = -1
|
||||||
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
|
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
|
||||||
i_batch += 1
|
i_batch += 1
|
||||||
rainfall = self.model(batch[0], training=False) # ((rainfall, water), dummy_label)
|
rainfall = self.model(batch[0], training=False) # (rainfall_embed, water)
|
||||||
|
|
||||||
for step in tf.unstack(rainfall, axis=0):
|
for step in tf.unstack(rainfall, axis=0):
|
||||||
yield step
|
yield step
|
||||||
|
|
|
@ -8,6 +8,7 @@ import re
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from aimodel.src.lib.io.writefile import writefile
|
||||||
|
|
||||||
from lib.io.handle_open import handle_open
|
from lib.io.handle_open import handle_open
|
||||||
from lib.ai.RainfallWaterContraster import RainfallWaterContraster
|
from lib.ai.RainfallWaterContraster import RainfallWaterContraster
|
||||||
|
@ -70,11 +71,15 @@ def run(args):
|
||||||
|
|
||||||
output_mode = MODE_TFRECORD if filepath_output.endswith(".tfrecord") or filepath_output.endswith(".tfrecord.gz") else MODE_JSONL
|
output_mode = MODE_TFRECORD if filepath_output.endswith(".tfrecord") or filepath_output.endswith(".tfrecord.gz") else MODE_JSONL
|
||||||
|
|
||||||
|
write_mode = "wt" if filepath_output.endswith(".gz") else "w"
|
||||||
|
if output_mode == MODE_TFRECORD:
|
||||||
|
write_mode = "wb"
|
||||||
|
|
||||||
handle = sys.stdout
|
handle = sys.stdout
|
||||||
if filepath_output != "-":
|
if filepath_output != "-":
|
||||||
handle = handle_open(
|
handle = handle_open(
|
||||||
filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0),
|
filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0),
|
||||||
"wt" if filepath_output.endswith(".gz") else "w"
|
write_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
|
@ -86,11 +91,13 @@ def run(args):
|
||||||
i_file = 0
|
i_file = 0
|
||||||
handle.close()
|
handle.close()
|
||||||
logger.write(f"PROGRESS:file {files_done}")
|
logger.write(f"PROGRESS:file {files_done}")
|
||||||
handle = handle_open(filepath_output.replace("$d", str(files_done+1)))
|
handle = handle_open(filepath_output.replace("$d", str(files_done+1)), write_mode)
|
||||||
|
|
||||||
if output_mode == MODE_JSONL:
|
if output_mode == MODE_JSONL:
|
||||||
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||||
elif output_mode == MODE_TFRECORD:
|
elif output_mode == MODE_TFRECORD:
|
||||||
|
if i == 0:
|
||||||
|
writefile(json.dumps({ }))
|
||||||
step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()])
|
step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()])
|
||||||
step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])
|
step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue