diff --git a/aimodel/src/subcommands/pretrain.py b/aimodel/src/subcommands/pretrain.py index 2d35fe6..ec947b7 100644 --- a/aimodel/src/subcommands/pretrain.py +++ b/aimodel/src/subcommands/pretrain.py @@ -53,7 +53,7 @@ def run(args): feature_dim=args.feature_dim, shape_rainfall=dataset_metadata["rainfallradar"], - shape_water=[ math.floor(value * 0.75) for value in dataset_metadata["waterdepth"] ] + shape_water=[ math.ceil(value * 0.75) for value in dataset_metadata["waterdepth"] ] ) ai.train(dataset_train, dataset_validate)