From ead80094253456e0b8a7395b04bdc54dd9416c8d Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Mon, 5 Sep 2022 15:36:40 +0100 Subject: [PATCH] pretrain: add CLI arg for size of watch prediction width/height --- aimodel/src/subcommands/pretrain.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aimodel/src/subcommands/pretrain.py b/aimodel/src/subcommands/pretrain.py index 91bc5a5..6d5d199 100644 --- a/aimodel/src/subcommands/pretrain.py +++ b/aimodel/src/subcommands/pretrain.py @@ -17,11 +17,13 @@ def parse_args(): parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 2048].", type=int) parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int) parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.") + parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int) return parser def run(args): - + if (not hasattr(args, "water_size")) or args.water_size == None: + args.water_size = 100 if (not hasattr(args, "batch_size")) or args.batch_size == None: args.batch_size = 64 if (not hasattr(args, "feature_dim")) or args.feature_dim == None: @@ -53,7 +55,7 @@ def run(args): feature_dim=args.feature_dim, metadata = read_metadata(args.input), - shape_water=[ 100, 100 ] # The DESIRED + shape_water=[ args.shape_water, args.shape_water ] # The DESIRED output shape. the actual data will be cropped to match this. ) ai.train(dataset_train, dataset_validate)