mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
dlr predict: allow for multiple outputs
This commit is contained in:
parent
864dfa802d
commit
c0c6e81c01
1 changed files with 6 additions and 3 deletions
|
@ -256,14 +256,16 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
||||||
|
|
||||||
|
|
||||||
def plot_predictions(filepath, input_items, colormap, model):
|
def plot_predictions(filepath, input_items, colormap, model):
|
||||||
|
i = 0
|
||||||
for input_pair in input_items:
|
for input_pair in input_items:
|
||||||
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||||
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
||||||
|
|
||||||
|
print("DEBUG:plot_predictions INFER", prediction_mask)
|
||||||
|
|
||||||
plot_samples_matplotlib(
|
plot_samples_matplotlib(
|
||||||
filepath,
|
filepath.replace("$$", i),
|
||||||
[
|
[
|
||||||
# input_tensor,
|
# input_tensor,
|
||||||
input_pair[1], #label_colourmap
|
input_pair[1], #label_colourmap
|
||||||
|
@ -271,6 +273,7 @@ def plot_predictions(filepath, input_items, colormap, model):
|
||||||
],
|
],
|
||||||
figsize=(18, 14)
|
figsize=(18, 14)
|
||||||
)
|
)
|
||||||
|
i += 1
|
||||||
|
|
||||||
def get_from_batched(dataset, count):
|
def get_from_batched(dataset, count):
|
||||||
result = []
|
result = []
|
||||||
|
@ -284,13 +287,13 @@ def get_from_batched(dataset, count):
|
||||||
|
|
||||||
|
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
|
||||||
get_from_batched(dataset_train, 4),
|
get_from_batched(dataset_train, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
|
||||||
get_from_batched(dataset_validate, 4),
|
get_from_batched(dataset_validate, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
|
|
Loading…
Reference in a new issue