From 1e682661dbf21cefec9f09564fe0b033799f0be3 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 14 Sep 2022 17:11:06 +0100 Subject: [PATCH] ai: kwargs in from_checkpoint --- aimodel/src/lib/ai/RainfallWaterContraster.py | 5 +---- aimodel/src/subcommands/pretrain_predict.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/aimodel/src/lib/ai/RainfallWaterContraster.py b/aimodel/src/lib/ai/RainfallWaterContraster.py index b177e95..e790aaa 100644 --- a/aimodel/src/lib/ai/RainfallWaterContraster.py +++ b/aimodel/src/lib/ai/RainfallWaterContraster.py @@ -51,10 +51,7 @@ class RainfallWaterContraster(object): } @staticmethod - def from_checkpoint(filepath_checkpoint, filepath_hyperparams=None): - if not filepath_checkpoint: - filepath_hyperparams = find_paramsjson(filepath_checkpoint) - hyperparams = json.loads(readfile(filepath_hyperparams)) + def from_checkpoint(filepath_checkpoint, **hyperparams): return RainfallWaterContraster(filepath_checkpoint=filepath_checkpoint, **hyperparams) diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index 6b8af8c..29cfecf 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -47,7 +47,7 @@ def run(args): filepath_output = args.output if hasattr(args, "output") and args.output != None else "-" - ai = RainfallWaterContraster.from_checkpoint(args.checkpoint) + ai = RainfallWaterContraster.from_checkpoint(args.checkpoint, **json.loads(readfile(args.params))) sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")