mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
train: add --arch; default to convnext_i_xtiny
This commit is contained in:
parent
5666c5a0d9
commit
11f91a7cf4
3 changed files with 7 additions and 2 deletions
|
@ -15,7 +15,7 @@ depths_dims = dict(
|
|||
)
|
||||
|
||||
|
||||
def do_convnext_inverse(layer_in, arch_name="convnext_tiny"):
|
||||
def do_convnext_inverse(layer_in, arch_name="convnext_i_xtiny"):
|
||||
return convnext_inverse(layer_in,
|
||||
depths=depths_dims[arch_name]["depths"],
|
||||
dims=depths_dims[arch_name]["dims"]
|
||||
|
|
|
@ -5,7 +5,7 @@ import tensorflow as tf
|
|||
|
||||
from .components.convnext_inverse import do_convnext_inverse
|
||||
|
||||
def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out, batch_size=64, summary_file=None):
|
||||
def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, summary_file=None):
|
||||
out_water_width, out_water_height = shape_water_out
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@ def parse_args():
|
|||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
||||
parser.add_argument("--water-threshold", help="The threshold at which a water cell should be considered water. Water depth values lower than this will be set to 0 (no water). Value unit is metres [default: 0.1].", type=int)
|
||||
parser.add_argument("--arch", help="Next fo the underlying convnext model to use [default: 0.1].", type=int)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
@ -35,6 +37,8 @@ def run(args):
|
|||
args.water_threshold = 1.5
|
||||
if (not hasattr(args, "water_size")) or args.water_size == None:
|
||||
args.water_size = 1.5
|
||||
if (not hasattr(args, "arch")) or args.arch == None:
|
||||
args.arch = "convnext_i_xtiny"
|
||||
|
||||
|
||||
# TODO: Validate args here.
|
||||
|
@ -61,6 +65,7 @@ def run(args):
|
|||
batch_size=args.batch_size,
|
||||
feature_dim_in=args.feature_dim,
|
||||
|
||||
model_arch=args.model_arch,
|
||||
metadata = read_metadata(args.input),
|
||||
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue