do_argmax

This commit is contained in:
Starbeamrainbowlabs 2022-11-24 18:11:03 +00:00
parent 6c09d5254d
commit 1f60f2a580
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -128,7 +128,7 @@ def do_png(args, ai, dataset, model_params):
if i % 100 == 0: if i % 100 == 0:
sys.stderr.write(f"Processed {i} items\r") sys.stderr.write(f"Processed {i} items\r")
def do_jsonl(args, ai, dataset, model_params): def do_jsonl(args, ai, dataset, model_params, do_argmax=False):
write_mode = "wt" if args.output.endswith(".gz") else "w" write_mode = "wt" if args.output.endswith(".gz") else "w"
handle = sys.stdout handle = sys.stdout
@ -156,7 +156,8 @@ def do_jsonl(args, ai, dataset, model_params):
i_batch = 0 i_batch = 0
for water_predict in water_predict_batch: for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ] # [ width, height, softmax_probabilities ] → [ batch, width, height ]
# water_predict = tf.math.argmax(water_predict, axis=-1) if do_argmax:
water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ] # [ width, height ]
water_actual = tf.squeeze(water_actual_batch[i_batch]) water_actual = tf.squeeze(water_actual_batch[i_batch])