train: add --arch; default to convnext_i_xtiny

This commit is contained in:
Starbeamrainbowlabs 2022-10-11 19:18:01 +01:00
parent 5666c5a0d9
commit 11f91a7cf4
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 7 additions and 2 deletions

View file

@ -15,7 +15,7 @@ depths_dims = dict(
)
def do_convnext_inverse(layer_in, arch_name="convnext_tiny"):
def do_convnext_inverse(layer_in, arch_name="convnext_i_xtiny"):
return convnext_inverse(layer_in,
depths=depths_dims[arch_name]["depths"],
dims=depths_dims[arch_name]["dims"]

View file

@ -5,7 +5,7 @@ import tensorflow as tf
from .components.convnext_inverse import do_convnext_inverse
def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out, batch_size=64, summary_file=None):
def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, summary_file=None):
out_water_width, out_water_height = shape_water_out

View file

@ -19,6 +19,8 @@ def parse_args():
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
parser.add_argument("--water-threshold", help="The threshold at which a water cell should be considered water. Water depth values lower than this will be set to 0 (no water). Value unit is metres [default: 0.1].", type=int)
parser.add_argument("--arch", help="Next fo the underlying convnext model to use [default: 0.1].", type=int)
return parser
@ -35,6 +37,8 @@ def run(args):
args.water_threshold = 1.5
if (not hasattr(args, "water_size")) or args.water_size == None:
args.water_size = 1.5
if (not hasattr(args, "arch")) or args.arch == None:
args.arch = "convnext_i_xtiny"
# TODO: Validate args here.
@ -61,6 +65,7 @@ def run(args):
batch_size=args.batch_size,
feature_dim_in=args.feature_dim,
model_arch=args.model_arch,
metadata = read_metadata(args.input),
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
)