mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 06:05:01 +00:00
plot_multi: actually fix the plots
This commit is contained in:
parent
9efc72db73
commit
636b316bfc
1 changed files with 17 additions and 9 deletions
|
@ -4,6 +4,7 @@ import sys
|
|||
import os
|
||||
import re
|
||||
import seaborn as sns
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
@ -19,11 +20,11 @@ def do_regex(source, regex):
|
|||
def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
|
||||
i = 0
|
||||
for train in train_list:
|
||||
ax.plot(train, label=model_names[i])
|
||||
ax.plot(train, label=model_names[i], linewidth=1)
|
||||
i += 1
|
||||
i = 0
|
||||
for val in val_list:
|
||||
ax.plot(val, label=f"val_{model_names[i]}")
|
||||
ax.plot(val, label=f"val_{model_names[i]}", linewidth=1)
|
||||
i += 1
|
||||
|
||||
ax.set_title(metric_name)
|
||||
|
@ -40,8 +41,8 @@ def make_dfs(filepaths_input):
|
|||
|
||||
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
||||
|
||||
fig = plt.figure(figsize=(10*resolution, 13*resolution))
|
||||
matplotlib.rcParams.update({'font.size': 15*resolution})
|
||||
fig = plt.figure(figsize=(10*resolution, 14*resolution))
|
||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
||||
train = [ df[colname] for df in dfs ]
|
||||
val = [ df[f"val_{colname}"] for df in dfs ]
|
||||
|
@ -52,17 +53,24 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
|||
ax = fig.add_subplot(3, 2, i+1)
|
||||
|
||||
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
||||
|
||||
# fig.tight_layout()
|
||||
|
||||
|
||||
# Ref https://stackoverflow.com/a/57484812/1460422
|
||||
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
|
||||
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
|
||||
lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ]
|
||||
fig.legend(lines, labels, loc='upper center', ncol=4)
|
||||
|
||||
legend = fig.legend(lines, labels, loc='outside upper center', ncol=3)
|
||||
|
||||
# Ref https://stackoverflow.com/a/48296983/1460422
|
||||
# change the line width for the legend
|
||||
for line in legend.get_lines():
|
||||
line.set_linewidth(4.0*resolution)
|
||||
|
||||
fig.tight_layout()
|
||||
plt.subplots_adjust(top=0.85)
|
||||
|
||||
target=os.path.join(dirpath_output, f"metrics.png")
|
||||
plt.savefig(target)
|
||||
plt.savefig(target, bbox_inches='tight')
|
||||
|
||||
sys.stderr.write(">>> Saved to ")
|
||||
sys.stdout.write(target)
|
||||
|
|
Loading…
Reference in a new issue