train-predict fixup

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 15:27:39 +01:00
parent 42aea7a0cc
commit c5b1501dba
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -88,7 +88,7 @@ def run(args):
sys.stderr.write(">>> Complete\n")
def do_png(args, ai, dataset, model_code):
if not os.path.exists(args.output):
if not os.path.exists(os.path.dirname(args.output)):
os.mkdir(os.path.dirname(args.output))
model_params = json.loads(readfile(args.params))
@ -106,10 +106,10 @@ def do_png(args, ai, dataset, model_code):
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ]
water_item = water[i_batch]
water_actual = water[i_batch]
segmentation_plot(
water_item, water_predict,
water_actual, water_predict,
model_code,
args.output.replace("+d", str(i))
)