mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
plot_multi: actually fix the plots
This commit is contained in:
parent
9efc72db73
commit
636b316bfc
1 changed files with 17 additions and 9 deletions
|
@ -4,6 +4,7 @@ import sys
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
@ -19,11 +20,11 @@ def do_regex(source, regex):
|
||||||
def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
|
def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
|
||||||
i = 0
|
i = 0
|
||||||
for train in train_list:
|
for train in train_list:
|
||||||
ax.plot(train, label=model_names[i])
|
ax.plot(train, label=model_names[i], linewidth=1)
|
||||||
i += 1
|
i += 1
|
||||||
i = 0
|
i = 0
|
||||||
for val in val_list:
|
for val in val_list:
|
||||||
ax.plot(val, label=f"val_{model_names[i]}")
|
ax.plot(val, label=f"val_{model_names[i]}", linewidth=1)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
ax.set_title(metric_name)
|
ax.set_title(metric_name)
|
||||||
|
@ -40,8 +41,8 @@ def make_dfs(filepaths_input):
|
||||||
|
|
||||||
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||||
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
||||||
|
matplotlib.rcParams.update({'font.size': 15*resolution})
|
||||||
fig = plt.figure(figsize=(10*resolution, 13*resolution))
|
fig = plt.figure(figsize=(10*resolution, 14*resolution))
|
||||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
||||||
train = [ df[colname] for df in dfs ]
|
train = [ df[colname] for df in dfs ]
|
||||||
val = [ df[f"val_{colname}"] for df in dfs ]
|
val = [ df[f"val_{colname}"] for df in dfs ]
|
||||||
|
@ -52,17 +53,24 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||||
ax = fig.add_subplot(3, 2, i+1)
|
ax = fig.add_subplot(3, 2, i+1)
|
||||||
|
|
||||||
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
||||||
|
|
||||||
# fig.tight_layout()
|
|
||||||
|
|
||||||
# Ref https://stackoverflow.com/a/57484812/1460422
|
# Ref https://stackoverflow.com/a/57484812/1460422
|
||||||
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
|
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
|
||||||
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
|
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
|
||||||
lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ]
|
lines, labels = [sum(lol, []) for lol in zip(*lines_labels) ]
|
||||||
fig.legend(lines, labels, loc='upper center', ncol=4)
|
legend = fig.legend(lines, labels, loc='outside upper center', ncol=3)
|
||||||
|
|
||||||
|
# Ref https://stackoverflow.com/a/48296983/1460422
|
||||||
|
# change the line width for the legend
|
||||||
|
for line in legend.get_lines():
|
||||||
|
line.set_linewidth(4.0*resolution)
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
plt.subplots_adjust(top=0.85)
|
||||||
|
|
||||||
target=os.path.join(dirpath_output, f"metrics.png")
|
target=os.path.join(dirpath_output, f"metrics.png")
|
||||||
plt.savefig(target)
|
plt.savefig(target, bbox_inches='tight')
|
||||||
|
|
||||||
sys.stderr.write(">>> Saved to ")
|
sys.stderr.write(">>> Saved to ")
|
||||||
sys.stdout.write(target)
|
sys.stdout.write(target)
|
||||||
|
|
Loading…
Reference in a new issue