train-mono: tidy up arg passing

This commit is contained in:
Starbeamrainbowlabs 2022-12-08 18:47:03 +00:00
parent b53db648bf
commit d37e7224f5
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 3 additions and 5 deletions

View file

@ -8,7 +8,7 @@ from .components.convnext_inverse import do_convnext_inverse
from .components.LayerStack2Image import LayerStack2Image
from .components.LossCrossentropy import LossCrossentropy
def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2, learning_rate=None, heightmap_input=False):
def model_rainfallwater_mono(metadata, model_arch_enc="convnext_xtiny", model_arch_dec="convnext_i_xtiny", feature_dim=512, batch_size=64, water_bins=2, learning_rate=None, heightmap_input=False):
"""Makes a new rainfall / waterdepth mono model.
Args:
@ -32,7 +32,7 @@ def model_rainfallwater_mono(metadata, shape_water_out, model_arch_enc="convnext
rainfall_channels += 1
print("RAINFALL channels", rainfall_channels, "width", rainfall_width, "height", rainfall_height, "HEIGHTMAP_INPUT", heightmap_input)
out_water_width, out_water_height = shape_water_out
layer_input = tf.keras.layers.Input(
shape=(rainfall_width, rainfall_height, rainfall_channels)

View file

@ -37,8 +37,6 @@ def run(args):
args.read_multiplier = 1.5
if (not hasattr(args, "water_threshold")) or args.water_threshold == None:
args.water_threshold = 0.1
if (not hasattr(args, "water_size")) or args.water_size == None:
args.water_size = 1.5
if (not hasattr(args, "bottleneck")) or args.bottleneck == None:
args.bottleneck = 512
if (not hasattr(args, "arch_enc")) or args.arch_enc == None:
@ -83,7 +81,7 @@ def run(args):
learning_rate = args.learning_rate,
metadata = read_metadata(args.input),
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
# shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
)
ai.train(dataset_train, dataset_validate)