diff --git a/aimodel/src/subcommands/pretrain.py b/aimodel/src/subcommands/pretrain.py index 2506b8f..5a2f68a 100644 --- a/aimodel/src/subcommands/pretrain.py +++ b/aimodel/src/subcommands/pretrain.py @@ -21,6 +21,13 @@ def parse_args(): return parser + +def count_batches(dataset): + count = 0 + for _ in dataset: + count += 1 + return count + def run(args): if (not hasattr(args, "water_size")) or args.water_size == None: args.water_size = 100 @@ -40,7 +47,11 @@ def run(args): dirpath_input=args.input, batch_size=args.batch_size, ) - dataset_metadata = read_metadata(args.input) + + print("BATCHES_TRAIN", count_batches(dataset_train)) + print("BATCHES_VALIDATE", count_batches(dataset_validate)) + + # for (items, label) in dataset_train: # print("ITEMS", len(items), [ item.shape for item in items ]) @@ -59,4 +70,3 @@ def run(args): ) ai.train(dataset_train, dataset_validate) - \ No newline at end of file