From 9f3ae96894ce359c5031571f1d53400194eaf638 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 6 Oct 2022 19:21:50 +0100 Subject: [PATCH] finish wiring for --water-size --- aimodel/src/lib/dataset/dataset_segmenter.py | 4 ++-- aimodel/src/subcommands/train.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aimodel/src/lib/dataset/dataset_segmenter.py b/aimodel/src/lib/dataset/dataset_segmenter.py index a864a4b..e59cb2d 100644 --- a/aimodel/src/lib/dataset/dataset_segmenter.py +++ b/aimodel/src/lib/dataset/dataset_segmenter.py @@ -46,7 +46,7 @@ def parse_item(metadata, shape_water_desired, water_threshold=0.1): return tf.function(parse_item_inner) -def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], water_threshold=0.1, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, prefetch=True, shuffle=True): +def make_dataset(filepaths, metadata, shape_water_desired=[100,100], water_threshold=0.1, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64, prefetch=True, shuffle=True): if "NO_PREFETCH" in os.environ: logger.info("disabling data prefetching.") @@ -56,7 +56,7 @@ def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], water_thres ) if shuffle: dataset = dataset.shuffle(shuffle_buffer_size) - dataset = dataset.map(parse_item(metadata, shape_water_desired=shape_watch_desired, water_threshold=water_threshold), num_parallel_calls=tf.data.AUTOTUNE) + dataset = dataset.map(parse_item(metadata, shape_water_desired=shape_water_desired, water_threshold=water_threshold), num_parallel_calls=tf.data.AUTOTUNE) if batch_size != None: dataset = dataset.batch(batch_size, drop_remainder=True) diff --git a/aimodel/src/subcommands/train.py b/aimodel/src/subcommands/train.py index c95a8d2..e47d6ce 100644 --- a/aimodel/src/subcommands/train.py +++ b/aimodel/src/subcommands/train.py @@ -33,6 +33,8 @@ def run(args): args.read_multiplier = 1.5 if (not hasattr(args, "water_threshold")) or args.water_threshold == None: args.water_threshold = 1.5 + if (not hasattr(args, "water_size")) or args.water_size == None: + args.water_size = 1.5 # TODO: Validate args here. @@ -43,6 +45,7 @@ def run(args): dirpath_input=args.input, batch_size=args.batch_size, water_threshold=args.water_threshold, + shape_water_desired=[args.water_size, args.water_size] ) dataset_metadata = read_metadata(args.input)