plot_multi: actually fix the plots

This commit is contained in:
Starbeamrainbowlabs 2023-07-05 16:35:26 +01:00
parent 9efc72db73
commit 636b316bfc
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -4,6 +4,7 @@ import sys
import os import os
import re import re
import seaborn as sns import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
@ -19,11 +20,11 @@ def do_regex(source, regex):
def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output): def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
i = 0 i = 0
for train in train_list: for train in train_list:
ax.plot(train, label=model_names[i]) ax.plot(train, label=model_names[i], linewidth=1)
i += 1 i += 1
i = 0 i = 0
for val in val_list: for val in val_list:
ax.plot(val, label=f"val_{model_names[i]}") ax.plot(val, label=f"val_{model_names[i]}", linewidth=1)
i += 1 i += 1
ax.set_title(metric_name) ax.set_title(metric_name)
@ -40,8 +41,8 @@ def make_dfs(filepaths_input):
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ] dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
matplotlib.rcParams.update({'font.size': 15*resolution})
fig = plt.figure(figsize=(10*resolution, 13*resolution)) fig = plt.figure(figsize=(10*resolution, 14*resolution))
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())): for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
train = [ df[colname] for df in dfs ] train = [ df[colname] for df in dfs ]
val = [ df[f"val_{colname}"] for df in dfs ] val = [ df[f"val_{colname}"] for df in dfs ]
@ -52,17 +53,24 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
ax = fig.add_subplot(3, 2, i+1) ax = fig.add_subplot(3, 2, i+1)
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output) plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
# fig.tight_layout()
# Ref https://stackoverflow.com/a/57484812/1460422 # Ref https://stackoverflow.com/a/57484812/1460422
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ] # lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
lines_labels = [ fig.axes[0].get_legend_handles_labels() ] lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ] lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ]
fig.legend(lines, labels, loc='upper center', ncol=4) legend = fig.legend(lines, labels, loc='outside upper center', ncol=3)
# Ref https://stackoverflow.com/a/48296983/1460422
# change the line width for the legend
for line in legend.get_lines():
line.set_linewidth(4.0*resolution)
fig.tight_layout()
plt.subplots_adjust(top=0.85)
target=os.path.join(dirpath_output, f"metrics.png") target=os.path.join(dirpath_output, f"metrics.png")
plt.savefig(target) plt.savefig(target, bbox_inches='tight')
sys.stderr.write(">>> Saved to ") sys.stderr.write(">>> Saved to ")
sys.stdout.write(target) sys.stdout.write(target)