mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
train-mono: tidy up arg passing
This commit is contained in:
parent
b53db648bf
commit
d37e7224f5
2 changed files with 3 additions and 5 deletions
|
@ -8,7 +8,7 @@ from .components.convnext_inverse import do_convnext_inverse
|
||||||
from .components.LayerStack2Image import LayerStack2Image
|
from .components.LayerStack2Image import LayerStack2Image
|
||||||
from .components.LossCrossentropy import LossCrossentropy
|
from .components.LossCrossentropy import LossCrossentropy
|
||||||
|
|
||||||
def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2, learning_rate=None, heightmap_input=False):
|
def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2, learning_rate=None, heightmap_input=False):
|
||||||
"""Makes a new rainfall / waterdepth mono model.
|
"""Makes a new rainfall / waterdepth mono model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -32,7 +32,7 @@ def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext
|
||||||
rainfall_channels += 1
|
rainfall_channels += 1
|
||||||
|
|
||||||
print("RAINFALL channels", rainfall_channels, "width", rainfall_width, "height", rainfall_height, "HEIGHTMAP_INPUT", heightmap_input)
|
print("RAINFALL channels", rainfall_channels, "width", rainfall_width, "height", rainfall_height, "HEIGHTMAP_INPUT", heightmap_input)
|
||||||
out_water_width, out_water_height = shape_water_out
|
|
||||||
|
|
||||||
layer_input = tf.keras.layers.Input(
|
layer_input = tf.keras.layers.Input(
|
||||||
shape=(rainfall_width, rainfall_height, rainfall_channels)
|
shape=(rainfall_width, rainfall_height, rainfall_channels)
|
||||||
|
|
|
@ -37,8 +37,6 @@ def run(args):
|
||||||
args.read_multiplier = 1.5
|
args.read_multiplier = 1.5
|
||||||
if (not hasattr(args, "water_threshold")) or args.water_threshold == None:
|
if (not hasattr(args, "water_threshold")) or args.water_threshold == None:
|
||||||
args.water_threshold = 0.1
|
args.water_threshold = 0.1
|
||||||
if (not hasattr(args, "water_size")) or args.water_size == None:
|
|
||||||
args.water_size = 1.5
|
|
||||||
if (not hasattr(args, "bottleneck")) or args.bottleneck == None:
|
if (not hasattr(args, "bottleneck")) or args.bottleneck == None:
|
||||||
args.bottleneck = 512
|
args.bottleneck = 512
|
||||||
if (not hasattr(args, "arch_enc")) or args.arch_enc == None:
|
if (not hasattr(args, "arch_enc")) or args.arch_enc == None:
|
||||||
|
@ -83,7 +81,7 @@ def run(args):
|
||||||
learning_rate = args.learning_rate,
|
learning_rate = args.learning_rate,
|
||||||
|
|
||||||
metadata = read_metadata(args.input),
|
metadata = read_metadata(args.input),
|
||||||
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
|
# shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
|
||||||
)
|
)
|
||||||
|
|
||||||
ai.train(dataset_train, dataset_validate)
|
ai.train(dataset_train, dataset_validate)
|
||||||
|
|
Loading…
Reference in a new issue