plot_multi: actually fix the plots

This commit is contained in:
Starbeamrainbowlabs 2023-07-05 16:35:26 +01:00
parent 9efc72db73
commit 636b316bfc
Signed by: sbrl
GPG Key ID: 1BE5172E637709C2
1 changed files with 17 additions and 9 deletions

View File

@ -4,6 +4,7 @@ import sys
import os
import re
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
@ -19,11 +20,11 @@ def do_regex(source, regex):
def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
i = 0
for train in train_list:
ax.plot(train, label=model_names[i])
ax.plot(train, label=model_names[i], linewidth=1)
i += 1
i = 0
for val in val_list:
ax.plot(val, label=f"val_{model_names[i]}")
ax.plot(val, label=f"val_{model_names[i]}", linewidth=1)
i += 1
ax.set_title(metric_name)
@ -40,8 +41,8 @@ def make_dfs(filepaths_input):
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
fig = plt.figure(figsize=(10*resolution, 13*resolution))
matplotlib.rcParams.update({'font.size': 15*resolution})
fig = plt.figure(figsize=(10*resolution, 14*resolution))
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
train = [ df[colname] for df in dfs ]
val = [ df[f"val_{colname}"] for df in dfs ]
@ -52,17 +53,24 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
ax = fig.add_subplot(3, 2, i+1)
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
# fig.tight_layout()
# Ref https://stackoverflow.com/a/57484812/1460422
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ]
fig.legend(lines, labels, loc='upper center', ncol=4)
legend = fig.legend(lines, labels, loc='outside upper center', ncol=3)
# Ref https://stackoverflow.com/a/48296983/1460422
# change the line width for the legend
for line in legend.get_lines():
line.set_linewidth(4.0*resolution)
fig.tight_layout()
plt.subplots_adjust(top=0.85)
target=os.path.join(dirpath_output, f"metrics.png")
plt.savefig(target)
plt.savefig(target, bbox_inches='tight')
sys.stderr.write(">>> Saved to ")
sys.stdout.write(target)