mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
plot_metrics_multi: add RESOLUTION env var
This commit is contained in:
parent
18db54f0a7
commit
4bbc4c29c4
1 changed files with 8 additions and 3 deletions
|
@ -38,10 +38,10 @@ def make_dfs(filepaths_input):
|
||||||
print("DEBUG filepath_input", filepath_input)
|
print("DEBUG filepath_input", filepath_input)
|
||||||
dfs = pd.read_csv(filepath_input, sep="\t")
|
dfs = pd.read_csv(filepath_input, sep="\t")
|
||||||
|
|
||||||
def plot_metrics(filepaths_input, model_names, dirpath_output):
|
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||||
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
||||||
|
|
||||||
fig = plt.figure(figsize=(10,13))
|
fig = plt.figure(figsize=(10*resolution, 13*resolution))
|
||||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
||||||
train = [ df[colname] for df in dfs ]
|
train = [ df[colname] for df in dfs ]
|
||||||
val = [ df[f"val_{colname}"] for df in dfs ]
|
val = [ df[f"val_{colname}"] for df in dfs ]
|
||||||
|
@ -114,5 +114,10 @@ Usage:
|
||||||
|
|
||||||
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
||||||
|
|
||||||
plot_metrics(FILEPATHS_INPUT, MODEL_NAMES, DIRPATH_OUTPUT)
|
plot_metrics(
|
||||||
|
FILEPATHS_INPUT,
|
||||||
|
MODEL_NAMES,
|
||||||
|
DIRPATH_OUTPUT,
|
||||||
|
resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue