diff --git a/aimodel/src/lib/dataset/dataset_segmenter.py b/aimodel/src/lib/dataset/dataset_segmenter.py
index e59cb2d..d191bca 100644
--- a/aimodel/src/lib/dataset/dataset_segmenter.py
+++ b/aimodel/src/lib/dataset/dataset_segmenter.py
@@ -72,7 +72,7 @@ def get_filepaths(dirpath_input):
 		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
 	)))
 
-def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5, water_threshold=0.1):
+def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5, water_threshold=0.1, shape_water_desired=[100,100]):
 	filepaths = get_filepaths(dirpath_input)
 	filepaths_count = len(filepaths)
 	dataset_splitpoint = math.floor(filepaths_count * train_percentage)
@@ -82,8 +82,8 @@ def dataset_segmenter(dirpath_input, batch_size=64, train_percentage=0.8, parall
 	
 	metadata = read_metadata(dirpath_input)
 	
-	dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold)
-	dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold)
+	dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold, shape_water_desired=shape_water_desired)
+	dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier, water_threshold=water_threshold, shape_water_desired=shape_water_desired)
 	
 	return dataset_train, dataset_validate #, filepaths
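For context, a minimal sketch of how the updated entrypoint might be invoked after this change. The directory path and shape values below are placeholders, and the import path is an assumption based on the file path in the diff (aimodel/src/lib/dataset/dataset_segmenter.py); adjust both to match the actual package layout.

# Hypothetical usage of the new dataset_segmenter() signature.
# Import path assumed from the diff's file location; not confirmed by the patch.
from lib.dataset.dataset_segmenter import dataset_segmenter

# "path/to/dataset" and the shape are placeholder values for illustration.
dataset_train, dataset_validate = dataset_segmenter(
	"path/to/dataset",       # directory scanned by get_filepaths()
	batch_size=64,
	train_percentage=0.8,    # 80% train / 20% validation split
	water_threshold=0.1,
	shape_water_desired=[128, 128],  # new parameter, forwarded to make_dataset()
)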