diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py
index 57f6335..3e18d1f 100644
--- a/aimodel/src/subcommands/pretrain_predict.py
+++ b/aimodel/src/subcommands/pretrain_predict.py
@@ -33,6 +33,19 @@ def parse_args():
 
 	return parser
 
+def handle_open_modeset(filepath, write_mode, handle_mode):
+	"""Open an output handle suited to the given output mode.
+
+	For MODE_TFRECORD a tf.io.TFRecordWriter is returned, GZIP-compressed at
+	level 9 when filepath ends in ".gz"; write_mode is not used on that path.
+	Any other handle_mode defers to handle_open(filepath, write_mode).
+	"""
+	if handle_mode == MODE_TFRECORD:
+		options = tf.io.TFRecordOptions(compression_type="GZIP", compression_level=9) if filepath.endswith(".gz") else tf.io.TFRecordOptions()
+		return tf.io.TFRecordWriter(filepath, options=options)
+	else:
+		return handle_open(filepath, write_mode)
+
 def run(args):
 
 	# Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
@@ -81,9 +94,10 @@ def run(args):
 	handle = sys.stdout
 	filepath_metadata = None
 	if filepath_output != "-":
-		handle = handle_open(
+		handle = handle_open_modeset(
 			filepath_output if args.records_per_file <= 0 else filepath_output.replace("+d", str(0)),
-			write_mode
+			write_mode=write_mode,
+			handle_mode=output_mode
 		)
 		filepath_metadata = os.path.join(os.path.dirname(filepath_output), "metadata.json")
 
@@ -99,7 +113,7 @@ def run(args):
 			i_file = 0
 			handle.close()
 			logger.info(f"PROGRESS:file {files_done}")
-			handle = handle_open(filepath_output.replace("+d", str(files_done+1)), write_mode)
+			handle = handle_open_modeset(filepath_output.replace("+d", str(files_done+1)), write_mode, handle_mode=output_mode)
 
 		if output_mode == MODE_JSONL:
 			handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422