From 1876a8883ce5627c7b4588c2b3c7c2821538039e Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 14 Sep 2022 15:12:07 +0100 Subject: [PATCH] =?UTF-8?q?ai=20pretrain-predict:=20fix=20-=20=E2=86=92=20?= =?UTF-8?q?=5F=20in=20cli=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aimodel/src/parse_args.py | 2 +- aimodel/src/subcommands/pretrain_predict.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aimodel/src/parse_args.py b/aimodel/src/parse_args.py index f5acb72..162be03 100644 --- a/aimodel/src/parse_args.py +++ b/aimodel/src/parse_args.py @@ -25,7 +25,7 @@ For more information, do src/index.py --help. """) exit(0) - subcommand = re.sub(r'[^a-z0-9-]', '', sys.argv[1]) + subcommand = re.sub(r'-', '_', re.sub(r'[^a-z0-9-]', '', sys.argv[1])) subcommand_argparser = importlib.import_module(f"subcommands.{subcommand}").parse_args diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index a4e5cbf..03140a1 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -68,11 +68,16 @@ def run(args): if filepath_output != "-": handle = io.open(filepath_output, "w") + i = 0 for rainfall, water in ai.embed(dataset): handle.write(json.dumps({ "rainfall": rainfall.numpy().tolist(), "water": water.numpy().tolist() }, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422 + + if i == 0 or i % 1000 == 0: + sys.stderr.write(f"[pretrain:predict] STEP {i}\r") + i += 1 handle.close()