diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 71d127c..a5333ea 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -305,6 +305,11 @@ def plot_samples_matplotlib(filepath, display_list): plt.colorbar() plt.savefig(filepath, dpi=200) +def save_samples(filepath, save_list): + handle = io.open(filepath, "a") + json.dump(save_list, handle) + handle.write("\n") + handle.close() def plot_predictions(filepath, input_items, colormap, model): i = 0 @@ -327,6 +332,11 @@ def plot_predictions(filepath, input_items, colormap, model): prediction_colormap ] ) + + save_samples( + filepath.replace("_$$", "").replace(".png", ".jsonl"), + prediction_mask + ) i += 1 def get_from_batched(dataset, count):