mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
add water thresholding
This commit is contained in:
parent
404dc30f08
commit
4ee7f2a0d6
3 changed files with 17 additions and 9 deletions
|
@ -22,7 +22,7 @@ def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out,
|
|||
|
||||
# TODO: An attention layer here instead of a dense layer, with a skip connection perhaps?
|
||||
layer_next = tf.keras.layers.Dense(32)(layer_next)
|
||||
layer_next = tf.keras.layers.Conv2D(out_water_channels, kernel_size=1, activation="softmax", padding="same")(layer_next)
|
||||
# NOTE(review): softmax normalises over the last axis, so with a single output channel every value is 1.0 — confirm whether this should be 2 channels (to pair with SparseCategoricalCrossentropy) or a 1-channel sigmoid.
layer_next = tf.keras.layers.Conv2D(1, kernel_size=1, activation="softmax", padding="same")(layer_next)
|
||||
|
||||
model = tf.keras.Model(
|
||||
inputs = layer_input,
|
||||
|
@ -31,7 +31,8 @@ def model_rainfallwater_segmentation(metadata, feature_dim_in, shape_water_out,
|
|||
|
||||
model.compile(
|
||||
optimizer="Adam",
|
||||
loss="" # TODO: set this to binary cross-entropy loss
|
||||
# Trailing comma is required here: the metrics= keyword argument follows as the next argument of model.compile().
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
|
||||
metrics=["accuracy"]
|
||||
)
|
||||
|
||||
return model
|
|
@ -14,7 +14,7 @@ from .shuffle import shuffle
|
|||
|
||||
|
||||
# TO PARSE:
|
||||
def parse_item(metadata, shape_water_desired):
|
||||
def parse_item(metadata, shape_water_desired, water_threshold=0.1):
|
||||
water_width_source, water_height_source, _water_channels_source = metadata["waterdepth"]
|
||||
water_width_target, water_height_target = shape_water_desired
|
||||
water_offset_x = math.ceil((water_width_source - water_width_target) / 2)
|
||||
|
@ -34,6 +34,8 @@ def parse_item(metadata, shape_water_desired):
|
|||
# rainfall = [ feature_dim ]
|
||||
# water = [ width, height, 1 ]
|
||||
|
||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||
|
||||
water = tf.image.crop_to_bounding_box(water, water_offset_x, water_offset_y, water_width_target, water_height_target)
|
||||
|
||||
print("DEBUG:dataset ITEM rainfall:shape", rainfall.shape, "water:shape", water.shape)
|
||||
|
@ -70,7 +72,7 @@ def get_filepaths(dirpath_input):
|
|||
[ file.path for file in os.scandir(dirpath_input) ] # DirEntry.path is dirpath_input joined with the entry name, so it is only absolute when dirpath_input itself is absolute
|
||||
)))
|
||||
|
||||
def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||
def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5, water_threshold=0.1):
|
||||
filepaths = get_filepaths(dirpath_input)
|
||||
filepaths_count = len(filepaths)
|
||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||
|
@ -80,12 +82,12 @@ def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parall
|
|||
|
||||
metadata = read_metadata(dirpath_input)
|
||||
|
||||
dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||
dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||
dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold)
|
||||
dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold)
|
||||
|
||||
return dataset_train, dataset_validate #, filepaths
|
||||
|
||||
def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True):
|
||||
def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True, water_threshold=0.1):
|
||||
filepaths = get_filepaths(dirpath_input) if os.path.isdir(dirpath_input) else [ dirpath_input ]
|
||||
|
||||
return make_dataset(
|
||||
|
@ -94,7 +96,8 @@ def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True)
|
|||
parallel_reads_multiplier=parallel_reads_multiplier,
|
||||
batch_size=None,
|
||||
prefetch=prefetch,
|
||||
shuffle=False #even with shuffle=False we're not gonna get them all in the same order since we're reading in parallel
|
||||
shuffle=False, #even with shuffle=False we're not gonna get them all in the same order since we're reading in parallel
|
||||
water_threshold=water_threshold
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -18,6 +18,7 @@ def parse_args():
|
|||
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
||||
# Parsed as a float so that a user-supplied value can be multiplied by the core count, as the help text describes (without type= argparse would hand back a str).
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number on systems with high read latency to avoid starving the GPU of data.", type=float)
|
||||
parser.add_argument("--water-size", help="The width and height of the square of pixels that the model will predict. Smaller values crop the input more [default: 100].", type=int)
|
||||
# The threshold is a depth in metres with a fractional default (0.1), so it must be parsed as a float; type=int would reject values such as "0.1".
parser.add_argument("--water-threshold", help="The threshold at which a water cell should be considered water. Water depth values lower than this will be set to 0 (no water). Value unit is metres [default: 0.1].", type=float)
|
||||
|
||||
return parser
|
||||
|
||||
|
@ -30,6 +31,8 @@ def run(args):
|
|||
args.feature_dim = 512
|
||||
# NOTE(review): argparse derives the attribute name "reads_multiplier" from the --reads-multiplier flag, so unless "read_multiplier" is set elsewhere this check never sees the user's value — confirm which name is intended.
if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None:
|
||||
args.read_multiplier = 1.5
|
||||
if (not hasattr(args, "water_threshold")) or args.water_threshold == None:
|
||||
# Fall back to 0.1 (metres), matching the default documented in the --water-threshold help text and the default of dataset_segmenter().
args.water_threshold = 0.1
|
||||
|
||||
|
||||
# TODO: Validate args here.
|
||||
|
@ -39,6 +42,7 @@ def run(args):
|
|||
dataset_train, dataset_validate = dataset_segmenter(
|
||||
dirpath_input=args.input,
|
||||
batch_size=args.batch_size,
|
||||
water_threshold=args.water_threshold,
|
||||
)
|
||||
dataset_metadata = read_metadata(args.input)
|
||||
|
||||
|
@ -55,7 +59,7 @@ def run(args):
|
|||
feature_dim_in=args.feature_dim,
|
||||
|
||||
metadata = read_metadata(args.input),
|
||||
shape_water_out=[ args.water_size, args.water_size ] # The DESIRED output shape. the actual data will be cropped to match this.
|
||||
shape_water_out=[ args.water_size, args.water_size ], # The DESIRED output shape. the actual data will be cropped to match this.
|
||||
)
|
||||
|
||||
ai.train(dataset_train, dataset_validate)
|
||||
|
|
Loading…
Reference in a new issue