train-predict fixup

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 15:27:39 +01:00
parent 42aea7a0cc
commit c5b1501dba
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -88,7 +88,7 @@ def run(args):
sys.stderr.write(">>> Complete\n") sys.stderr.write(">>> Complete\n")
def do_png(args, ai, dataset, model_code): def do_png(args, ai, dataset, model_code):
if not os.path.exists(args.output): if not os.path.exists(os.path.dirname(args.output)):
os.mkdir(os.path.dirname(args.output)) os.mkdir(os.path.dirname(args.output))
model_params = json.loads(readfile(args.params)) model_params = json.loads(readfile(args.params))
@ -106,10 +106,10 @@ def do_png(args, ai, dataset, model_code):
# [ width, height, softmax_probabilities ] → [ batch, width, height ] # [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1) water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ] # [ width, height ]
water_item = water[i_batch] water_actual = water[i_batch]
segmentation_plot( segmentation_plot(
water_item, water_predict, water_actual, water_predict,
model_code, model_code,
args.output.replace("+d", str(i)) args.output.replace("+d", str(i))
) )