From c98c42fa7ef7ec9274327ed720ebc45609dd9deb Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Thu, 29 Aug 2024 16:44:19 +0100
Subject: [PATCH] dataset_mono: Implement validate_percentage + test_percentage support

This removes the train_percentage argument.

TODO: map this forwards to enable support in deeplabv3_plus_test_rainfall

...thinking about it, it's really not a test now, is it? Updating the
filename would be such a /hassle/ though....
---
 aimodel/src/lib/dataset/dataset_mono.py | 38 +++++++++++++++++++++++-------
 aimodel/src/subcommands/train_mono.py   | 14 +++++-------
 2 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/aimodel/src/lib/dataset/dataset_mono.py b/aimodel/src/lib/dataset/dataset_mono.py
index 7d0b560..2e363fa 100644
--- a/aimodel/src/lib/dataset/dataset_mono.py
+++ b/aimodel/src/lib/dataset/dataset_mono.py
@@ -16,6 +16,23 @@ from .primitives.remove_isolated_pixels import remove_isolated_pixels
 
 # TO PARSE:
 def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1, water_bins=2, heightmap=None, rainfall_scale_up=1, do_remove_isolated_pixels=True):
+	"""
+	Creates a function that parses a single TFRecord item from the dataset.
+	
+	Args:
+		metadata (dict): Metadata about the shapes of the dataset - rainfall radar, water depth data etc. This should be read automatically from the metadata.json file that's generated by previous pipeline steps that I forget at this time.
+		output_size (int): The desired output size of the water depth data.
+		input_size (str or int): The desired input size of the rainfall radar data. If "same", it will be set to the same as the output_size.
+		water_threshold (float): The threshold to use for binarizing the water depth data.
+		water_bins (int): The number of bins to use for the water depth data (e.g. for one-hot encoding).
+		heightmap (tf.Tensor): An optional heightmap to include as an additional channel in the rainfall radar data.
+		rainfall_scale_up (int): A factor by which to scale up the rainfall radar data.
+		do_remove_isolated_pixels (bool): Whether to remove isolated pixels from the water depth data or not. Isolated pixels are binarized [=1] pixels that are surrounded by [=0] pixels on (4|8 TODO FIGURE OUT) sides.
+	
+	Returns:
+		A function that takes a single TFRecord item and returns the parsed rainfall radar and water depth data.
+	"""
+	
 	if input_size == "same":
 		input_size = output_size # This is almost always the case with e.g. the DeepLabV3+ model
 	
@@ -144,22 +161,29 @@ def get_filepaths(dirpath_input, do_shuffle=True):
 	return result
 
-# TODO refactor this to validate_percentage=0.2 and test_percentage=0, but DON'T FORGET TO CHECK ***ALL*** usages of this FIRST and update them afterwards!
-def dataset_mono(dirpath_input, train_percentage=0.8, **kwargs):
+def dataset_mono(dirpath_input, validate_percentage=0.2, test_percentage=0, **kwargs):
 	filepaths = get_filepaths(dirpath_input)
 	filepaths_count = len(filepaths)
-	dataset_splitpoint = math.floor(filepaths_count * train_percentage)
 	
-	filepaths_train = filepaths[:dataset_splitpoint]
-	filepaths_validate = filepaths[dataset_splitpoint:]
+	split_trainvalidate = math.floor(filepaths_count * (1 - (validate_percentage + test_percentage)))
+	split_validatetest = math.floor(filepaths_count * (1 - test_percentage))
 	
-	print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate)
+	filepaths_train = filepaths[:split_trainvalidate]
+	filepaths_validate = filepaths[split_trainvalidate:split_validatetest]
+	filepaths_test = []
+	if test_percentage > 0:
+		filepaths_test = filepaths[split_validatetest:]
+	
+	print("DEBUG:dataset_mono filepaths_train", filepaths_train, "filepaths_validate", filepaths_validate, "filepaths_test", filepaths_test)
 	
 	metadata = read_metadata(dirpath_input)
 	
 	dataset_train = make_dataset(filepaths_train, metadata=metadata, **kwargs)
 	dataset_validate = make_dataset(filepaths_validate, metadata=metadata, **kwargs)
+	dataset_test = None
+	if test_percentage > 0:
+		dataset_test = make_dataset(filepaths_test, metadata=metadata, **kwargs)
 	
-	return dataset_train, dataset_validate #, filepaths
+	return dataset_train, dataset_validate, dataset_test #, filepaths
 
 def dataset_mono_predict(dirpath_input, batch_size=64, **kwargs):
 	"""Creates a tf.data.Dataset() for prediction using the contrastive learning model.

diff --git a/aimodel/src/subcommands/train_mono.py b/aimodel/src/subcommands/train_mono.py
index 08be066..88e7592 100644
--- a/aimodel/src/subcommands/train_mono.py
+++ b/aimodel/src/subcommands/train_mono.py
@@ -56,12 +56,12 @@ def run(args):
 	
 	
-	dataset_train, dataset_validate = dataset_mono(
-		dirpath_input=args.input,
-		batch_size=args.batch_size,
-		water_threshold=args.water_threshold,
-		output_size=args.water_size,
-		input_size=None, # Don't crop the input size
-		filepath_heightmap=args.heightmap
+	dataset_train, dataset_validate, dataset_test = dataset_mono( # dataset_test is None here: test_percentage defaults to 0
+		dirpath_input = args.input,
+		batch_size = args.batch_size,
+		water_threshold = args.water_threshold,
+		output_size = args.water_size,
+		input_size = None, # Don't crop the input size
+		filepath_heightmap = args.heightmap
 	)
 	
 	dataset_metadata = read_metadata(args.input)
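-- 
Usage sketch (not part of the commit; all values below are illustrative
assumptions, including the dataset path): how a caller might consume the new
three-way split. With 100 input files, validate_percentage=0.2 and
test_percentage=0.1, split_trainvalidate = floor(100 * 0.7) = 70 and
split_validatetest = floor(100 * 0.9) = 90, giving a 70/20/10
train/validate/test split over the shuffled filepaths.

	# Hypothetical call, not taken from this repository:
	dataset_train, dataset_validate, dataset_test = dataset_mono(
		"path/to/tfrecords",        # hypothetical dataset directory
		validate_percentage = 0.2,  # 20% of files -> validation split
		test_percentage     = 0.1,  # 10% of files -> test split
	)
	# dataset_test is a tf.data.Dataset here because test_percentage > 0;
	# with the default test_percentage=0 it comes back as None, so existing
	# two-way callers only need to add a third unpacking target.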