dataset_mono: Implement validate_percentage + test_percentage support

This removes the train_percentage argument

TODO: map this forwards to enable support in deeplabv3_plus_test_rainfall

...thinking about it, it's really not a test now, is it? Updating the filename would be such a /hassle/ though....
This commit is contained in:
Starbeamrainbowlabs 2024-08-29 16:44:19 +01:00
parent 5d62e3cee8
commit c98c42fa7e
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 38 additions and 12 deletions

View file

@ -16,6 +16,23 @@ from .primitives.remove_isolated_pixels import remove_isolated_pixels
# TO PARSE:
def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1, do_remove_isolated_pixels=True):
"""
Parse a single TFRecord item from the dataset.
Args:
metadata (dict): Metadata about the shapes of the dataset - rainfall radar, water depth data etc. This should be read automaticallyfrom the metadata.json file that's generated by previous pipeline steps that I forget at this time.
output_size (int): The desired output size of the water depth data.
input_size (str or int): The desired input size of the rainfall radar data. If "same", it will be set to the same as the output_size.
water_threshold (float): The threshold to use for binarizing the water depth data.
water_bins (int): The number of bins to use for the water depth data (e.g. for one-hot encoding).
heightmap (tf.Tensor): An optional heightmap to include as an additional channel in the rainfall radar data.
rainfall_scale_up (int): A factor to scale up the rainfall radar data.
do_remove_isolated_pixels (bool): Whether to remove isolated pixels from the water depth data or not. Isolated pixels are binaried [=1] pixels that are surrounded on (4|8 TODO FIGURE OUT) sides.
Returns:
A function that takes a single TFRecord item and returns the parsed rainfall radar and water depth data.
"""
if input_size == "same":
input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
@ -144,22 +161,31 @@ def get_filepaths(dirpath_input, do_shuffle=True):
return result
# TODO refactor this to validate_percentage=0.2 and test_percentage=0, but DON'T FORGET TO CHECK ***ALL*** usages of this FIRST and update them afterwards!
def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
def dataset_mono(dirpath_input, validate_percentage=0.2, test_percentage=0, **kwargs):
filepaths = get_filepaths(dirpath_input)
filepaths_count = len(filepaths)
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
filepaths_train = filepaths[:dataset_splitpoint]
filepaths_validate = filepaths[dataset_splitpoint:]
split_trainvalidate=math.floor(filepaths_count * (1-(validate_percentage+test_percentage)))
split_validatetest=math.floor(filepaths * (1 - test_percentage))
print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate)
filepaths_train = filepaths[:split_trainvalidate]
filepaths_validate = filepaths[split_trainvalidate:split_validatetest]
filepaths_test = []
if test_percentage > 0:
filepaths_test = filepaths[split_validatetest:]
print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate, "filepaths_test", filepaths_test)
metadata = read_metadata(dirpath_input)
dataset_train = make_dataset(filepaths_train, metadata=metadata, **kwargs)
dataset_validate = make_dataset(filepaths_validate, metadata=metadata, **kwargs)
dataset_test = None
if test_percentage > 0:
dataset_test = make_dataset(filepaths_test, metadata=metadata, **kwargs)
return dataset_train, dataset_validate #, filepaths
return dataset_train, dataset_validate, dataset_test #, filepaths
def dataset_mono_predict(dirpath_input, batch_size=64, **kwargs):
"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.