pretrain_predict fix write mode

This commit is contained in:
Starbeamrainbowlabs 2022-09-27 17:38:12 +01:00
parent f95fd8f9e4
commit d6ff3fb2ce
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 10 additions and 3 deletions

View file

@ -88,7 +88,7 @@ class RainfallWaterSegmenter(object):
i_batch = -1 i_batch = -1
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size): for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
i_batch += 1 i_batch += 1
rainfall = self.model(batch[0], training=False) # ((rainfall, water), dummy_label) rainfall = self.model(batch[0], training=False) # (rainfall_embed, water)
for step in tf.unstack(rainfall, axis=0): for step in tf.unstack(rainfall, axis=0):
yield step yield step

View file

@ -8,6 +8,7 @@ import re
from loguru import logger from loguru import logger
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from aimodel.src.lib.io.writefile import writefile
from lib.io.handle_open import handle_open from lib.io.handle_open import handle_open
from lib.ai.RainfallWaterContraster import RainfallWaterContraster from lib.ai.RainfallWaterContraster import RainfallWaterContraster
@ -70,11 +71,15 @@ def run(args):
output_mode = MODE_TFRECORD if filepath_output.endswith(".tfrecord") or filepath_output.endswith(".tfrecord.gz") else MODE_JSONL output_mode = MODE_TFRECORD if filepath_output.endswith(".tfrecord") or filepath_output.endswith(".tfrecord.gz") else MODE_JSONL
write_mode = "wt" if filepath_output.endswith(".gz") else "w"
if output_mode == MODE_TFRECORD:
write_mode = "wb"
handle = sys.stdout handle = sys.stdout
if filepath_output != "-": if filepath_output != "-":
handle = handle_open( handle = handle_open(
filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0), filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0),
"wt" if filepath_output.endswith(".gz") else "w" write_mode
) )
i = 0 i = 0
@ -86,11 +91,13 @@ def run(args):
i_file = 0 i_file = 0
handle.close() handle.close()
logger.write(f"PROGRESS:file {files_done}") logger.write(f"PROGRESS:file {files_done}")
handle = handle_open(filepath_output.replace("$d", str(files_done+1))) handle = handle_open(filepath_output.replace("$d", str(files_done+1)), write_mode)
if output_mode == MODE_JSONL: if output_mode == MODE_JSONL:
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
elif output_mode == MODE_TFRECORD: elif output_mode == MODE_TFRECORD:
if i == 0:
writefile(json.dumps({ }))
step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()]) step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()])
step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()]) step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])