mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
No need for a CLI arg for feature_dim_in - metadata should contain this
This commit is contained in:
parent
e201372252
commit
f12e6ab905
2 changed files with 18 additions and 5 deletions
|
@ -5,9 +5,24 @@ import tensorflow as tf
|
||||||
|
|
||||||
from .components.convnext_inverse import do_convnext_inverse
|
from .components.convnext_inverse import do_convnext_inverse
|
||||||
|
|
||||||
def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, summary_file=None):
|
|
||||||
|
def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, summary_file=None, water_bins=2):
|
||||||
|
"""Makes a new rainfall / waterdepth segmentation head model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): A dictionary of metadata about the dataset to use to build the model with.
|
||||||
|
feature_dim_in (int): The size of the feature dimension
|
||||||
|
shape_water_out (_type_): _description_
|
||||||
|
model_arch (str, optional): _description_. Defaults to "convnext_i_xtiny".
|
||||||
|
batch_size (int, optional): _description_. Defaults to 64.
|
||||||
|
summary_file (_type_, optional): _description_. Defaults to None.
|
||||||
|
water_bins (int, optional): _description_. Defaults to 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
out_water_width, out_water_height = shape_water_out
|
out_water_width, out_water_height = shape_water_out
|
||||||
|
feature_dim_in = metadata["rainfallradar"][0]
|
||||||
|
|
||||||
layer_input = tf.keras.layers.Input(
|
layer_input = tf.keras.layers.Input(
|
||||||
shape=(feature_dim_in)
|
shape=(feature_dim_in)
|
||||||
|
@ -31,7 +46,7 @@ def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out,
|
||||||
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
|
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
|
||||||
logger.warning("Warning: TODO implement attention from https://ieeexplore.ieee.org/document/9076883")
|
logger.warning("Warning: TODO implement attention from https://ieeexplore.ieee.org/document/9076883")
|
||||||
layer_next = tf.keras.layers.Dense(32)(layer_next)
|
layer_next = tf.keras.layers.Dense(32)(layer_next)
|
||||||
layer_next = tf.keras.layers.Conv2D(1, kernel_size=1, activation="softmax", padding="same")(layer_next)
|
layer_next = tf.keras.layers.Conv2D(water_bins, kernel_size=1, activation="softmax", padding="same")(layer_next)
|
||||||
|
|
||||||
model = tf.keras.Model(
|
model = tf.keras.Model(
|
||||||
inputs = layer_input,
|
inputs = layer_input,
|
||||||
|
|
|
@ -14,7 +14,6 @@ def parse_args():
|
||||||
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
||||||
parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True)
|
parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True)
|
||||||
parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True)
|
parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True)
|
||||||
parser.add_argument("--feature-dim", help="The size of the input feature dimension of the model [default: 512].", type=int)
|
|
||||||
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
||||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||||
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
||||||
|
@ -63,7 +62,6 @@ def run(args):
|
||||||
ai = RainfallWaterSegmenter(
|
ai = RainfallWaterSegmenter(
|
||||||
dir_output=args.output,
|
dir_output=args.output,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
feature_dim_in=args.feature_dim,
|
|
||||||
|
|
||||||
model_arch=args.arch,
|
model_arch=args.arch,
|
||||||
metadata = read_metadata(args.input),
|
metadata = read_metadata(args.input),
|
||||||
|
|
Loading…
Reference in a new issue