mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
pretrain: add CLI arg for size of watch prediction width/height
This commit is contained in:
parent
9d39215dd5
commit
ead8009425
1 changed files with 4 additions and 2 deletions
|
@ -17,11 +17,13 @@ def parse_args():
|
|||
parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 2048].", type=int)
|
||||
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
||||
|
||||
return parser
|
||||
|
||||
def run(args):
|
||||
|
||||
if (not hasattr(args, "water_size")) or args.water_size == None:
|
||||
args.water_size = 100
|
||||
if (not hasattr(args, "batch_size")) or args.batch_size == None:
|
||||
args.batch_size = 64
|
||||
if (not hasattr(args, "feature_dim")) or args.feature_dim == None:
|
||||
|
@ -53,7 +55,7 @@ def run(args):
|
|||
feature_dim=args.feature_dim,
|
||||
|
||||
metadata = read_metadata(args.input),
|
||||
shape_water=[ 100, 100 ] # The DESIRED
|
||||
shape_water=[ args.shape_water, args.shape_water ] # The DESIRED output shape. the actual data will be cropped to match this.
|
||||
)
|
||||
|
||||
ai.train(dataset_train, dataset_validate)
|
||||
|
|
Loading…
Reference in a new issue