From 6b17d45aad83ac23e12ca957ee78e77ce46fd4b0 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 22 Mar 2023 17:41:34 +0000 Subject: [PATCH] dlr: plot all metrics --- aimodel/src/deeplabv3_plus_test_rainfall.py | 48 ++++++++------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 10912a9..0f89e68 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -182,13 +182,21 @@ else: }) - # ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████ # ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██ # ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███ # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ # ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████ +def plot_metric(train, val, name, dir_output): + plt.plot(train, label=f"train_{name}") + plt.plot(val, label=f"val_{name}") + plt.title(name) + plt.xlabel("epoch") + plt.ylabel(name) + plt.savefig(os.path.join(dir_output, f"{name}.png")) + plt.close() + if PATH_CHECKPOINT is None: loss_fn = None if LOSS == "cross-entropy-dice": @@ -233,36 +241,14 @@ if PATH_CHECKPOINT is None: ) logger.info(">>> Training complete") logger.info(">>> Plotting graphs") - - plt.plot(history.history["loss"]) - plt.title("Training Loss") - plt.ylabel("loss") - plt.xlabel("epoch") - plt.savefig(os.path.join(DIR_OUTPUT, "loss.png")) - plt.close() - - plt.plot(history.history["accuracy"]) - plt.title("Training Accuracy") - plt.ylabel("accuracy") - plt.xlabel("epoch") - plt.savefig(os.path.join(DIR_OUTPUT, "acc.png")) - plt.close() - - plt.plot(history.history["val_loss"]) - plt.title("Validation Loss") - plt.ylabel("val_loss") - plt.xlabel("epoch") - plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png")) - plt.close() - - plt.plot(history.history["val_accuracy"]) - plt.title("Validation Accuracy") - plt.ylabel("val_accuracy") - plt.xlabel("epoch") - plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png")) - plt.close() - - + + plot_metric(history.history["loss"], history.history["val_loss"], "loss") + plot_metric(history.history["acc"], history.history["val_acc"], "accuracy") + plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice") + plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou") + plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity") + plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity") + # ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████ # ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██