From e44b5533b11094a4eec1f096e2d58dacb421f187 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 20 Jul 2023 15:49:09 +0100 Subject: [PATCH] plot_metrics_multi: add option to plot train/val separately --- aimodel/src/plot_metrics_multi.py | 33 +++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/aimodel/src/plot_metrics_multi.py b/aimodel/src/plot_metrics_multi.py index 91cd93d..e70a5f7 100755 --- a/aimodel/src/plot_metrics_multi.py +++ b/aimodel/src/plot_metrics_multi.py @@ -33,15 +33,29 @@ def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output): # plt.savefig(os.path.join(dir_output, f"{name}.png")) # plt.close() +def plot_metric_single(ax, data_list, metric_name, model_names, dir_output): + i = 0 + for item in data_list: + ax.plot(item, label=model_names[i], linewidth=1) + i += 1 + + ax.set_title(metric_name) + ax.set_xlabel("epoch") + ax.set_ylabel(metric_name) + # plt.savefig(os.path.join(dir_output, f"{name}.png")) + # plt.close() + + def make_dfs(filepaths_input): dfs = [] for filepath_input in filepaths_input: print("DEBUG filepath_input", filepath_input) dfs = pd.read_csv(filepath_input, sep="\t") -def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): + +def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1, train_val_separate=False): dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ] - matplotlib.rcParams.update({'font.size': 15*resolution}) + matplotlib.rcParams.update({'font.size': 15*resolution*(0.5 if train_val_separate else 1)}) fig = plt.figure(figsize=(10*resolution, 14*resolution)) for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())): train = [ df[colname] for df in dfs ] @@ -50,9 +64,15 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \ .replace("one_hot_mean_iou", "mean iou") - ax = fig.add_subplot(3, 2, i+1) + if train_val_separate: + ax = fig.add_subplot(3*2, 2*2, (i+1)*2 - 1) + plot_metric_single(ax, train, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output) + ax = fig.add_subplot(3*2, 2*2, (i+1)*2) + plot_metric_single(ax, val, metric_name=f"val_{colname_display}", model_names=model_names, dir_output=dirpath_output) + else: + ax = fig.add_subplot(3 * 2, 2, i+1) + plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output) - plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output) # Ref https://stackoverflow.com/a/57484812/1460422 @@ -91,7 +111,7 @@ Usage: echo -e "filepathA\\nfilepathB..." | [OUTPUT="path/to/output_dir"] [REGEX_NAME=''] path/to/plot_metrics_multi.py """) sys.exit() - + TRAIN_VAL_SEPARATE = True if "TRAIN_VAL_SEPARATE" in os.environ else False REGEX_NAME = os.environ["REGEX_NAME"] if "REGEX_NAME" in os.environ else None if REGEX_NAME is None and len(sys.argv) >= 1: REGEX_NAME = sys.argv[1] @@ -126,6 +146,7 @@ Usage: FILEPATHS_INPUT, MODEL_NAMES, DIRPATH_OUTPUT, - resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1 + resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1, + train_val_separate=TRAIN_VAL_SEPARATE )