dlr: plot all metrics

This commit is contained in:
Starbeamrainbowlabs 2023-03-22 17:41:34 +00:00
parent e565c36149
commit 6b17d45aad
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -182,13 +182,21 @@ else:
}) })
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████ # ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██ # ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███ # ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████ # ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
def plot_metric(train, val, name, dir_output):
plt.plot(train, label=f"train_{name}")
plt.plot(val, label=f"val_{name}")
plt.title(name)
plt.xlabel("epoch")
plt.ylabel(name)
plt.savefig(os.path.join(dir_output, f"{name}.png"))
plt.close()
if PATH_CHECKPOINT is None: if PATH_CHECKPOINT is None:
loss_fn = None loss_fn = None
if LOSS == "cross-entropy-dice": if LOSS == "cross-entropy-dice":
@ -233,36 +241,14 @@ if PATH_CHECKPOINT is None:
) )
logger.info(">>> Training complete") logger.info(">>> Training complete")
logger.info(">>> Plotting graphs") logger.info(">>> Plotting graphs")
plt.plot(history.history["loss"]) plot_metric(history.history["loss"], history.history["val_loss"], "loss")
plt.title("Training Loss") plot_metric(history.history["acc"], history.history["val_acc"], "accuracy")
plt.ylabel("loss") plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice")
plt.xlabel("epoch") plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou")
plt.savefig(os.path.join(DIR_OUTPUT, "loss.png")) plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity")
plt.close() plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity")
plt.plot(history.history["accuracy"])
plt.title("Training Accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
plt.close()
plt.plot(history.history["val_loss"])
plt.title("Validation Loss")
plt.ylabel("val_loss")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
plt.close()
plt.plot(history.history["val_accuracy"])
plt.title("Validation Accuracy")
plt.ylabel("val_accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
plt.close()
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████ # ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
# ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██ # ██ ████ ██ ██ ██ ██ ██ ██ ████ ██ ██ ██