mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
plot_metrics_multi: add RESOLUTION env var
This commit is contained in:
parent
18db54f0a7
commit
4bbc4c29c4
1 changed files with 8 additions and 3 deletions
|
@ -38,10 +38,10 @@ def make_dfs(filepaths_input):
|
|||
print("DEBUG filepath_input", filepath_input)
|
||||
dfs = pd.read_csv(filepath_input, sep="\t")
|
||||
|
||||
def plot_metrics(filepaths_input, model_names, dirpath_output):
|
||||
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
||||
|
||||
fig = plt.figure(figsize=(10,13))
|
||||
fig = plt.figure(figsize=(10*resolution, 13*resolution))
|
||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
||||
train = [ df[colname] for df in dfs ]
|
||||
val = [ df[f"val_{colname}"] for df in dfs ]
|
||||
|
@ -114,5 +114,10 @@ Usage:
|
|||
|
||||
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
||||
|
||||
plot_metrics(FILEPATHS_INPUT, MODEL_NAMES, DIRPATH_OUTPUT)
|
||||
plot_metrics(
|
||||
FILEPATHS_INPUT,
|
||||
MODEL_NAMES,
|
||||
DIRPATH_OUTPUT,
|
||||
resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue