Bugfix: modeset to enable TFRecordWriter instead of bare handle

This commit is contained in:
Starbeamrainbowlabs 2022-10-06 20:07:59 +01:00
parent e9a8e2eb57
commit f883986eaa
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -33,6 +33,13 @@ def parse_args():
return parser return parser
def handle_open_modeset(filepath, write_mode, handle_mode):
    """Open an output handle suited to the requested output mode.

    When ``handle_mode`` is ``MODE_TFRECORD`` a ``tf.io.TFRecordWriter`` is
    returned; GZIP compression (level 9) is enabled if ``filepath`` ends
    with ``.gz``. ``write_mode`` is not consulted in that case.

    For every other mode the call is delegated to
    ``handle_open(filepath, write_mode)`` and its result returned unchanged.
    """
    # Anything that is not TFRecord output goes through the plain opener.
    if handle_mode != MODE_TFRECORD:
        return handle_open(filepath, write_mode)

    # Compression is chosen purely from the filename extension.
    if filepath.endswith(".gz"):
        record_options = tf.io.TFRecordOptions(compression_type="GZIP", compression_level=9)
    else:
        record_options = tf.io.TFRecordOptions()

    return tf.io.TFRecordWriter(filepath, options=record_options)
def run(args): def run(args):
# Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/ # Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
@ -81,9 +88,10 @@ def run(args):
handle = sys.stdout handle = sys.stdout
filepath_metadata = None filepath_metadata = None
if filepath_output != "-": if filepath_output != "-":
handle = handle_open( handle = handle_open_modeset(
filepath_output if args.records_per_file <= 0 else filepath_output.replace("+d", str(0)), filepath_output if args.records_per_file <= 0 else filepath_output.replace("+d", str(0)),
write_mode write_mode=write_mode,
handle_mode=output_mode
) )
filepath_metadata = os.path.join(os.path.dirname(filepath_output), "metadata.json") filepath_metadata = os.path.join(os.path.dirname(filepath_output), "metadata.json")
@ -99,7 +107,7 @@ def run(args):
i_file = 0 i_file = 0
handle.close() handle.close()
logger.info(f"PROGRESS:file {files_done}") logger.info(f"PROGRESS:file {files_done}")
handle = handle_open(filepath_output.replace("+d", str(files_done+1)), write_mode) handle = handle_open_modeset(filepath_output.replace("+d", str(files_done+1)), write_mode, handle_mode=output_mode)
if output_mode == MODE_JSONL: if output_mode == MODE_JSONL:
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422