rename shuffle arg

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 16:35:45 +01:00
parent c98d8d05dd
commit 612735aaae
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 6 additions and 6 deletions

View file

@ -66,12 +66,12 @@ def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression
return dataset return dataset
def get_filepaths(dirpath_input, shuffle=True): def get_filepaths(dirpath_input, do_shuffle=True):
result = list(filter( result = list(filter(
lambda filepath: str(filepath).endswith(".tfrecord.gz"), lambda filepath: str(filepath).endswith(".tfrecord.gz"),
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath [ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
)) ))
if shuffle: if do_shuffle:
result = shuffle(result) result = shuffle(result)
else: else:
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0])) result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
@ -105,7 +105,7 @@ def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True)
Returns: Returns:
tf.data.Dataset: A tensorflow Dataset for the given input files. tf.data.Dataset: A tensorflow Dataset for the given input files.
""" """
filepaths = get_filepaths(dirpath_input, shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ] filepaths = get_filepaths(dirpath_input, do_shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
return make_dataset( return make_dataset(
filepaths=filepaths, filepaths=filepaths,

View file

@ -63,12 +63,12 @@ def make_dataset(filepaths, metadata, shape_water_desired=[100,100], water_thres
return dataset return dataset
def get_filepaths(dirpath_input, shuffle=True): def get_filepaths(dirpath_input, do_shuffle=True):
result = list(filter( result = list(filter(
lambda filepath: str(filepath).endswith(".tfrecord.gz"), lambda filepath: str(filepath).endswith(".tfrecord.gz"),
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath [ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
)) ))
if shuffle: if do_shuffle:
result = shuffle(result) result = shuffle(result)
else: else:
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0])) result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
@ -102,7 +102,7 @@ def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True,
Returns: Returns:
tf.data.Dataset: A tensorflow Dataset for the given input files. tf.data.Dataset: A tensorflow Dataset for the given input files.
""" """
filepaths = get_filepaths(dirpath_input, shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ] filepaths = get_filepaths(dirpath_input, do_shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
return make_dataset( return make_dataset(
filepaths=filepaths, filepaths=filepaths,