dlr: plot all metrics

This commit is contained in:
Starbeamrainbowlabs 2023-03-22 17:41:34 +00:00
parent e565c36149
commit 6b17d45aad
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -182,13 +182,21 @@ else:
})
# ████████ ██████ █████ ██ ███ ██ ██ ███ ██ ██████
# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
# ██ ██████ ███████ ██ ██ ██ ██ ██ ██ ██ ██ ██ ███
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ██ ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
def plot_metric(train, val, name, dir_output):
plt.plot(train, label=f"train_{name}")
plt.plot(val, label=f"val_{name}")
plt.title(name)
plt.xlabel("epoch")
plt.ylabel(name)
plt.savefig(os.path.join(dir_output, f"{name}.png"))
plt.close()
if PATH_CHECKPOINT is None:
loss_fn = None
if LOSS == "cross-entropy-dice":
@ -234,34 +242,12 @@ if PATH_CHECKPOINT is None:
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
plt.plot(history.history["loss"])
plt.title("Training Loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "loss.png"))
plt.close()
plt.plot(history.history["accuracy"])
plt.title("Training Accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "acc.png"))
plt.close()
plt.plot(history.history["val_loss"])
plt.title("Validation Loss")
plt.ylabel("val_loss")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_loss.png"))
plt.close()
plt.plot(history.history["val_accuracy"])
plt.title("Validation Accuracy")
plt.ylabel("val_accuracy")
plt.xlabel("epoch")
plt.savefig(os.path.join(DIR_OUTPUT, "val_acc.png"))
plt.close()
plot_metric(history.history["loss"], history.history["val_loss"], "loss")
plot_metric(history.history["acc"], history.history["val_acc"], "accuracy")
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice")
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou")
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity")
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity")
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████