mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-24 10:13:00 +00:00
dataset_mono: Implement validate_percentage + test_percentage support
This removes the train_percentage argument TODO: map this forwards to enable support in deeplabv3_plus_test_rainfall ...thinking about it, it's really not a test now, is it? Updating the filename would be such a /hassle/ though....
This commit is contained in:
parent
5d62e3cee8
commit
c98c42fa7e
2 changed files with 38 additions and 12 deletions
|
@ -16,6 +16,23 @@ from .primitives.remove_isolated_pixels import remove_isolated_pixels
|
|||
|
||||
# TO PARSE:
|
||||
def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1, do_remove_isolated_pixels=True):
|
||||
"""
|
||||
Parse a single TFRecord item from the dataset.
|
||||
|
||||
Args:
|
||||
metadata (dict): Metadata about the shapes of the dataset - rainfall radar, water depth data etc. This should be read automaticallyfrom the metadata.json file that's generated by previous pipeline steps that I forget at this time.
|
||||
output_size (int): The desired output size of the water depth data.
|
||||
input_size (str or int): The desired input size of the rainfall radar data. If "same", it will be set to the same as the output_size.
|
||||
water_threshold (float): The threshold to use for binarizing the water depth data.
|
||||
water_bins (int): The number of bins to use for the water depth data (e.g. for one-hot encoding).
|
||||
heightmap (tf.Tensor): An optional heightmap to include as an additional channel in the rainfall radar data.
|
||||
rainfall_scale_up (int): A factor to scale up the rainfall radar data.
|
||||
do_remove_isolated_pixels (bool): Whether to remove isolated pixels from the water depth data or not. Isolated pixels are binaried [=1] pixels that are surrounded on (4|8 TODO FIGURE OUT) sides.
|
||||
|
||||
Returns:
|
||||
A function that takes a single TFRecord item and returns the parsed rainfall radar and water depth data.
|
||||
"""
|
||||
|
||||
if input_size == "same":
|
||||
input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
|
||||
|
||||
|
@ -144,22 +161,31 @@ def get_filepaths(dirpath_input, do_shuffle=True):
|
|||
return result
|
||||
|
||||
# TODO refactor this to validate_percentage=0.2 and test_percentage=0, but DON'T FORGET TO CHECK ***ALL*** usages of this FIRST and update them afterwards!
|
||||
def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
|
||||
def dataset_mono(dirpath_input, validate_percentage=0.2, test_percentage=0, **kwargs):
|
||||
filepaths = get_filepaths(dirpath_input)
|
||||
filepaths_count = len(filepaths)
|
||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||
|
||||
filepaths_train = filepaths[:dataset_splitpoint]
|
||||
filepaths_validate = filepaths[dataset_splitpoint:]
|
||||
split_trainvalidate=math.floor(filepaths_count * (1-(validate_percentage+test_percentage)))
|
||||
split_validatetest=math.floor(filepaths * (1 - test_percentage))
|
||||
|
||||
print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate)
|
||||
|
||||
filepaths_train = filepaths[:split_trainvalidate]
|
||||
filepaths_validate = filepaths[split_trainvalidate:split_validatetest]
|
||||
filepaths_test = []
|
||||
if test_percentage > 0:
|
||||
filepaths_test = filepaths[split_validatetest:]
|
||||
|
||||
print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate, "filepaths_test", filepaths_test)
|
||||
|
||||
metadata = read_metadata(dirpath_input)
|
||||
|
||||
dataset_train = make_dataset(filepaths_train, metadata=metadata, **kwargs)
|
||||
dataset_validate = make_dataset(filepaths_validate, metadata=metadata, **kwargs)
|
||||
dataset_test = None
|
||||
if test_percentage > 0:
|
||||
dataset_test = make_dataset(filepaths_test, metadata=metadata, **kwargs)
|
||||
|
||||
return dataset_train, dataset_validate #, filepaths
|
||||
return dataset_train, dataset_validate, dataset_test #, filepaths
|
||||
|
||||
def dataset_mono_predict(dirpath_input, batch_size=64, **kwargs):
|
||||
"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.
|
||||
|
|
Loading…
Reference in a new issue