mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
dlr predict: allow for multiple outputs
This commit is contained in:
parent
864dfa802d
commit
c0c6e81c01
1 changed files with 6 additions and 3 deletions
|
@ -256,14 +256,16 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
|||
|
||||
|
||||
def plot_predictions(filepath, input_items, colormap, model):
|
||||
i = 0
|
||||
for input_pair in input_items:
|
||||
prediction_mask = infer(image_tensor=input_pair[0], model=model)
|
||||
# label_colourmap = decode_segmentation_masks(input_pair[1], colormap, 2)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 2)
|
||||
|
||||
print("DEBUG:plot_predictions INFER", prediction_mask)
|
||||
|
||||
plot_samples_matplotlib(
|
||||
filepath,
|
||||
filepath.replace("$$", i),
|
||||
[
|
||||
# input_tensor,
|
||||
input_pair[1], #label_colourmap
|
||||
|
@ -271,6 +273,7 @@ def plot_predictions(filepath, input_items, colormap, model):
|
|||
],
|
||||
figsize=(18, 14)
|
||||
)
|
||||
i += 1
|
||||
|
||||
def get_from_batched(dataset, count):
|
||||
result = []
|
||||
|
@ -284,13 +287,13 @@ def get_from_batched(dataset, count):
|
|||
|
||||
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||
os.path.join(DIR_OUTPUT, "predict_train_$$.png"),
|
||||
get_from_batched(dataset_train, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||
os.path.join(DIR_OUTPUT, "predict_validate_$$.png"),
|
||||
get_from_batched(dataset_validate, 4),
|
||||
colormap,
|
||||
model=model
|
||||
|
|
Loading…
Reference in a new issue