From d6ff3fb2ce62deb480b399fc0e0bc44ae488d613 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Tue, 27 Sep 2022 17:38:12 +0100 Subject: [PATCH] pretrain_predict fix write mode --- aimodel/src/lib/ai/RainfallWaterSegmentor.py | 2 +- aimodel/src/subcommands/pretrain_predict.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/aimodel/src/lib/ai/RainfallWaterSegmentor.py b/aimodel/src/lib/ai/RainfallWaterSegmentor.py index 72c5393..5419d0b 100644 --- a/aimodel/src/lib/ai/RainfallWaterSegmentor.py +++ b/aimodel/src/lib/ai/RainfallWaterSegmentor.py @@ -88,7 +88,7 @@ class RainfallWaterSegmenter(object): i_batch = -1 for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size): i_batch += 1 - rainfall = self.model(batch[0], training=False) # ((rainfall, water), dummy_label) + rainfall = self.model(batch[0], training=False) # (rainfall_embed, water) for step in tf.unstack(rainfall, axis=0): yield step diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index 9d3cad5..716ebe3 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -8,6 +8,7 @@ import re from loguru import logger import tensorflow as tf import numpy as np +from aimodel.src.lib.io.writefile import writefile from lib.io.handle_open import handle_open from lib.ai.RainfallWaterContraster import RainfallWaterContraster @@ -70,11 +71,15 @@ def run(args): output_mode = MODE_TFRECORD if filepath_output.endswith(".tfrecord") or filepath_output.endswith(".tfrecord.gz") else MODE_JSONL + write_mode = "wt" if filepath_output.endswith(".gz") else "w" + if output_mode == MODE_TFRECORD: + write_mode = "wb" + handle = sys.stdout if filepath_output != "-": handle = handle_open( filepath_output if args.records_per_file <= 0 else filepath_output.replace("$d", 0), - "wt" if filepath_output.endswith(".gz") else "w" + write_mode ) i = 0 @@ -86,11 +91,13 @@ def run(args): i_file = 0 handle.close() logger.write(f"PROGRESS:file {files_done}") - handle = handle_open(filepath_output.replace("$d", str(files_done+1))) + handle = handle_open(filepath_output.replace("$d", str(files_done+1)), write_mode) if output_mode == MODE_JSONL: handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 elif output_mode == MODE_TFRECORD: + if i == 0: + writefile(json.dumps({ })) step_rainfall = tf.train.BytesList(value=[tf.io.serialize_tensor(step_rainfall, name="rainfall").numpy()]) step_water = tf.train.BytesList(value=[tf.io.serialize_tensor(step_water, name="water").numpy()])