mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
rename shuffle arg
This commit is contained in:
parent
c98d8d05dd
commit
612735aaae
2 changed files with 6 additions and 6 deletions
|
@ -66,12 +66,12 @@ def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def get_filepaths(dirpath_input, shuffle=True):
|
def get_filepaths(dirpath_input, do_shuffle=True):
|
||||||
result = list(filter(
|
result = list(filter(
|
||||||
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||||
))
|
))
|
||||||
if shuffle:
|
if do_shuffle:
|
||||||
result = shuffle(result)
|
result = shuffle(result)
|
||||||
else:
|
else:
|
||||||
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
|
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
|
||||||
|
@ -105,7 +105,7 @@ def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True)
|
||||||
Returns:
|
Returns:
|
||||||
tf.data.Dataset: A tensorflow Dataset for the given input files.
|
tf.data.Dataset: A tensorflow Dataset for the given input files.
|
||||||
"""
|
"""
|
||||||
filepaths = get_filepaths(dirpath_input, shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
|
filepaths = get_filepaths(dirpath_input, do_shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
|
||||||
|
|
||||||
return make_dataset(
|
return make_dataset(
|
||||||
filepaths=filepaths,
|
filepaths=filepaths,
|
||||||
|
|
|
@ -63,12 +63,12 @@ def make_dataset(filepaths, metadata, shape_water_desired=[100,100], water_thres
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def get_filepaths(dirpath_input, shuffle=True):
|
def get_filepaths(dirpath_input, do_shuffle=True):
|
||||||
result = list(filter(
|
result = list(filter(
|
||||||
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||||
))
|
))
|
||||||
if shuffle:
|
if do_shuffle:
|
||||||
result = shuffle(result)
|
result = shuffle(result)
|
||||||
else:
|
else:
|
||||||
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
|
result = sorted(result, key=lambda filepath: int(os.path.basename(filepath).split(".", 1)[0]))
|
||||||
|
@ -102,7 +102,7 @@ def dataset_predict(dirpath_input, parallel_reads_multiplier=1.5, prefetch=True,
|
||||||
Returns:
|
Returns:
|
||||||
tf.data.Dataset: A tensorflow Dataset for the given input files.
|
tf.data.Dataset: A tensorflow Dataset for the given input files.
|
||||||
"""
|
"""
|
||||||
filepaths = get_filepaths(dirpath_input, shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
|
filepaths = get_filepaths(dirpath_input, do_shuffle=False) if os.path.isdir(dirpath_input) else [ dirpath_input ]
|
||||||
|
|
||||||
return make_dataset(
|
return make_dataset(
|
||||||
filepaths=filepaths,
|
filepaths=filepaths,
|
||||||
|
|
Loading…
Reference in a new issue