plot_metrics_multi: add RESOLUTION env var

This commit is contained in:
Starbeamrainbowlabs 2023-06-14 15:50:37 +01:00
parent 18db54f0a7
commit 4bbc4c29c4
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -38,10 +38,10 @@ def make_dfs(filepaths_input):
print("DEBUG filepath_input", filepath_input) print("DEBUG filepath_input", filepath_input)
dfs = pd.read_csv(filepath_input, sep="\t") dfs = pd.read_csv(filepath_input, sep="\t")
def plot_metrics(filepaths_input, model_names, dirpath_output): def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ] dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
fig = plt.figure(figsize=(10,13)) fig = plt.figure(figsize=(10*resolution, 13*resolution))
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())): for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
train = [ df[colname] for df in dfs ] train = [ df[colname] for df in dfs ]
val = [ df[f"val_{colname}"] for df in dfs ] val = [ df[f"val_{colname}"] for df in dfs ]
@ -114,5 +114,10 @@ Usage:
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd() DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
plot_metrics(FILEPATHS_INPUT, MODEL_NAMES, DIRPATH_OUTPUT) plot_metrics(
FILEPATHS_INPUT,
MODEL_NAMES,
DIRPATH_OUTPUT,
resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1
)