mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 17:23:01 +00:00
train-predict fixup
This commit is contained in:
parent
42aea7a0cc
commit
c5b1501dba
1 changed files with 3 additions and 3 deletions
|
@ -88,7 +88,7 @@ def run(args):
|
||||||
sys.stderr.write(">>> Complete\n")
|
sys.stderr.write(">>> Complete\n")
|
||||||
|
|
||||||
def do_png(args, ai, dataset, model_code):
|
def do_png(args, ai, dataset, model_code):
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(os.path.dirname(args.output)):
|
||||||
os.mkdir(os.path.dirname(args.output))
|
os.mkdir(os.path.dirname(args.output))
|
||||||
|
|
||||||
model_params = json.loads(readfile(args.params))
|
model_params = json.loads(readfile(args.params))
|
||||||
|
@ -106,10 +106,10 @@ def do_png(args, ai, dataset, model_code):
|
||||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||||
# [ width, height ]
|
# [ width, height ]
|
||||||
water_item = water[i_batch]
|
water_actual = water[i_batch]
|
||||||
|
|
||||||
segmentation_plot(
|
segmentation_plot(
|
||||||
water_item, water_predict,
|
water_actual, water_predict,
|
||||||
model_code,
|
model_code,
|
||||||
args.output.replace("+d", str(i))
|
args.output.replace("+d", str(i))
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue