From 636b316bfc49e27d54134d29c617c7401da32ccd Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 5 Jul 2023 16:35:26 +0100 Subject: [PATCH] plot_multi: actually fix the plots --- aimodel/src/plot_metrics_multi.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/aimodel/src/plot_metrics_multi.py b/aimodel/src/plot_metrics_multi.py index eefdb61..91cd93d 100755 --- a/aimodel/src/plot_metrics_multi.py +++ b/aimodel/src/plot_metrics_multi.py @@ -4,6 +4,7 @@ import sys import os import re import seaborn as sns +import matplotlib import matplotlib.pyplot as plt import pandas as pd @@ -19,11 +20,11 @@ def do_regex(source, regex): def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output): i = 0 for train in train_list: - ax.plot(train, label=model_names[i]) + ax.plot(train, label=model_names[i], linewidth=1) i += 1 i = 0 for val in val_list: - ax.plot(val, label=f"val_{model_names[i]}") + ax.plot(val, label=f"val_{model_names[i]}", linewidth=1) i += 1 ax.set_title(metric_name) @@ -40,8 +41,8 @@ def make_dfs(filepaths_input): def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ] - - fig = plt.figure(figsize=(10*resolution, 13*resolution)) + matplotlib.rcParams.update({'font.size': 15*resolution}) + fig = plt.figure(figsize=(10*resolution, 14*resolution)) for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())): train = [ df[colname] for df in dfs ] val = [ df[f"val_{colname}"] for df in dfs ] @@ -52,17 +53,24 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): ax = fig.add_subplot(3, 2, i+1) plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output) - - # fig.tight_layout() + # Ref https://stackoverflow.com/a/57484812/1460422 # lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ] lines_labels = [ fig.axes[0].get_legend_handles_labels() ] lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ] - fig.legend(lines, labels, loc='upper center', ncol=4) - + legend = fig.legend(lines, labels, loc='outside upper center', ncol=3) + + # Ref https://stackoverflow.com/a/48296983/1460422 + # change the line width for the legend + for line in legend.get_lines(): + line.set_linewidth(4.0*resolution) + + fig.tight_layout() + plt.subplots_adjust(top=0.85) + target=os.path.join(dirpath_output, f"metrics.png") - plt.savefig(target) + plt.savefig(target, bbox_inches='tight') sys.stderr.write(">>> Saved to ") sys.stdout.write(target)