diff --git a/aimodel/src/lib/ai/RainfallWaterContraster.py b/aimodel/src/lib/ai/RainfallWaterContraster.py index b177e95..e790aaa 100644 --- a/aimodel/src/lib/ai/RainfallWaterContraster.py +++ b/aimodel/src/lib/ai/RainfallWaterContraster.py @@ -51,10 +51,7 @@ class RainfallWaterContraster(object): } @staticmethod - def from_checkpoint(filepath_checkpoint, filepath_hyperparams=None): - if not filepath_checkpoint: - filepath_hyperparams = find_paramsjson(filepath_checkpoint) - hyperparams = json.loads(readfile(filepath_hyperparams)) + def from_checkpoint(filepath_checkpoint, **hyperparams): return RainfallWaterContraster(filepath_checkpoint=filepath_checkpoint, **hyperparams) diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py index 6b8af8c..29cfecf 100644 --- a/aimodel/src/subcommands/pretrain_predict.py +++ b/aimodel/src/subcommands/pretrain_predict.py @@ -47,7 +47,7 @@ def run(args): filepath_output = args.output if hasattr(args, "output") and args.output != None else "-" - ai = RainfallWaterContraster.from_checkpoint(args.checkpoint) + ai = RainfallWaterContraster.from_checkpoint(args.checkpoint, **json.loads(readfile(args.params))) sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")