mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-04 17:13:02 +00:00
do_argmax
This commit is contained in:
parent
6c09d5254d
commit
1f60f2a580
1 changed files with 3 additions and 2 deletions
|
@ -128,7 +128,7 @@ def do_png(args, ai, dataset, model_params):
|
|||
if i % 100 == 0:
|
||||
sys.stderr.write(f"Processed {i} items\r")
|
||||
|
||||
def do_jsonl(args, ai, dataset, model_params):
|
||||
def do_jsonl(args, ai, dataset, model_params, do_argmax=False):
|
||||
write_mode = "wt" if args.output.endswith(".gz") else "w"
|
||||
|
||||
handle = sys.stdout
|
||||
|
@ -156,7 +156,8 @@ def do_jsonl(args, ai, dataset, model_params):
|
|||
i_batch = 0
|
||||
for water_predict in water_predict_batch:
|
||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||
# water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
if do_argmax:
|
||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
# [ width, height ]
|
||||
water_actual = tf.squeeze(water_actual_batch[i_batch])
|
||||
|
||||
|
|
Loading…
Reference in a new issue