From 4bbc4c29c423254afea22824e83cc72405b60dfc Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Wed, 14 Jun 2023 15:50:37 +0100 Subject: [PATCH] plot_metrics_multi: add RESOLUTION env var --- aimodel/src/plot_metrics_multi.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/aimodel/src/plot_metrics_multi.py b/aimodel/src/plot_metrics_multi.py index 7d1b6c9..eefdb61 100755 --- a/aimodel/src/plot_metrics_multi.py +++ b/aimodel/src/plot_metrics_multi.py @@ -38,10 +38,10 @@ def make_dfs(filepaths_input): print("DEBUG filepath_input", filepath_input) dfs = pd.read_csv(filepath_input, sep="\t") -def plot_metrics(filepaths_input, model_names, dirpath_output): +def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1): dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ] - fig = plt.figure(figsize=(10,13)) + fig = plt.figure(figsize=(10*resolution, 13*resolution)) for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())): train = [ df[colname] for df in dfs ] val = [ df[f"val_{colname}"] for df in dfs ] @@ -114,5 +114,10 @@ Usage: DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd() - plot_metrics(FILEPATHS_INPUT, MODEL_NAMES, DIRPATH_OUTPUT) + plot_metrics( + FILEPATHS_INPUT, + MODEL_NAMES, + DIRPATH_OUTPUT, + resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1 + )