pretrain: add CLI arg for size of watch prediction width/height

This commit is contained in:
Starbeamrainbowlabs 2022-09-05 15:36:40 +01:00
parent 9d39215dd5
commit ead8009425
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -17,11 +17,13 @@ def parse_args():
parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 2048].", type=int)
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
return parser
def run(args):
if (not hasattr(args, "water_size")) or args.water_size == None:
args.water_size = 100
if (not hasattr(args, "batch_size")) or args.batch_size == None:
args.batch_size = 64
if (not hasattr(args, "feature_dim")) or args.feature_dim == None:
@ -53,7 +55,7 @@ def run(args):
feature_dim=args.feature_dim,
metadata = read_metadata(args.input),
shape_water=[ 100, 100 ] # The DESIRED
shape_water=[ args.shape_water, args.shape_water ] # The DESIRED output shape. the actual data will be cropped to match this.
)
ai.train(dataset_train, dataset_validate)