Mirror of https://github.com/sbrl/research-rainfallradar (synced 2024-12-22 06:05:01 +00:00)
scripts/crossval-stbl: finish off the script
TODO: switch out median absolute deviation for something else when Nina replies
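For reference, the median absolute deviation (MAD) mentioned in the TODO is what the reworked script computes per epoch across runs via scipy.stats.median_abs_deviation. A minimal sketch of that calculation on toy data (the array and values below are invented for illustration, not taken from the repository):

# Sketch: per-epoch MAD across 3 runs (toy data, not from the repo).
import numpy as np
import scipy.stats

values = np.array([      # rows = epochs, columns = runs
	[0.52, 0.55, 0.49],
	[0.61, 0.64, 0.60],
	[0.70, 0.68, 0.73],
])
mad_per_epoch = scipy.stats.median_abs_deviation(values, axis=1)
print(mad_per_epoch)     # one MAD value per epoch: ≈ [0.03, 0.01, 0.02]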
This commit is contained in:
parent: fda40b05c9
commit: 5a412ddc26

2 changed files with 119 additions and 76 deletions
aimodel/scripts/crossval-stbl.py (executable file, 119 lines changed)
@@ -0,0 +1,119 @@
#!/usr/bin/env python3

import os
import sys

from loguru import logger
import pandas as pd
import scipy.stats
from matplotlib import pyplot as plt

# This script analyses metrics.tsv files from a series of identical experiments and reports aggregate statistics on them.
# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0.


if len(sys.argv) <= 1:
	print("""
Usage:
	scripts/crossval-stbl.py {{path/to/directory}}

...in which the given directory contains a series of experiment root directories to include in the statistical analysis.

This script is not picky about the format of the data in metrics.tsv, so long as it's in the form:

epoch	metric_A	metric_B	…
0	val:float	val:float	…
1	val:float	val:float	…
2	val:float	val:float	…
⋮
""")
	sys.exit(0)

DIRPATH = sys.argv[1] # [0] == script path

files = 0
metrics = {}
epochs = None

# Each subdirectory of DIRPATH is one experiment run containing a metrics.tsv;
# collect every metric's per-epoch values across runs.
for filepath in os.scandir(DIRPATH):
	if not os.path.isdir(filepath):
		continue
	tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t")
	
	# metrics.append(tbl)
	
	for column in tbl.columns:
		if column == "epoch":
			if epochs is None:
				epochs = tbl[column]
			continue # Row index implicitly retains this
		
		if column not in metrics:
			metrics[column] = []
		
		metrics[column].append(tbl[column].values)
		# print(column, tbl[column])
	
	# print("DEBUG:metrics", tbl)
	files += 1

logger.info(f"Read {files} files into crossval-stbl{files} analysis")

stats = {}

for metric in metrics.keys():
	# Reshape to one DataFrame per metric: rows = epochs, columns = runs
	metrics[metric] = pd.DataFrame(metrics[metric]).transpose()
	
	if metric not in stats:
		stats[metric] = {}
	
	stats[metric]["mad"] = scipy.stats.median_abs_deviation(
		metrics[metric], axis=1
	) # median absolute deviation
	stats[metric]["stddev"] = metrics[metric].std(axis=1)
	stats[metric]["mean"] = metrics[metric].mean(axis=1)
	stats[metric]["min"] = metrics[metric].min(axis=1)
	stats[metric]["max"] = metrics[metric].max(axis=1)
	stats[metric]["agg_min"] = stats[metric]["min"].min()
	stats[metric]["agg_max"] = stats[metric]["max"].max()
	stats[metric]["agg_stddev"] = metrics[metric].stack().std()
	stats[metric]["agg_mean"] = metrics[metric].stack().mean()
	stats[metric]["agg_mad"] = scipy.stats.median_abs_deviation(
		metrics[metric].stack()
	) # median absolute deviation
	
	# print(stats[metric])
	
	# Plot each metric: outer band = min..max across runs, inner band = mean ± MAD, line = mean
	plt.figure(figsize=(12, 8))
	plt.ylim(min(0, stats[metric]["agg_min"]), max(1, stats[metric]["agg_max"]))
	plt.grid(visible=True, which="major", axis="y", linewidth=2)
	plt.grid(visible=True, which="minor", axis="y", linewidth=1)
	plt.minorticks_on()
	plt.fill_between(
		epochs,
		stats[metric]["min"],
		stats[metric]["max"],
		alpha=0.2,
		facecolor="#B7DE28",
		edgecolor="#FDE724",
		linestyle="dotted",
		linewidth=1,
	)
	plt.fill_between(
		epochs,
		stats[metric]["mean"] - stats[metric]["mad"],
		stats[metric]["mean"] + stats[metric]["mad"],
		alpha=0.5,
		facecolor="#228A8D",
		edgecolor="#3CBB74",
		linestyle="dashed",
		linewidth=1,
	)
	plt.plot(epochs, stats[metric]["mean"], color="#450C54")
	plt.title(f"{metric} // crossval-stbl{files}")
	plt.xlabel("epoch")
	plt.ylabel(metric)
	plt.savefig(os.path.join(DIRPATH, f"crossval-stbl{files}_{metric}.png"))
	plt.close()

logger.success(f"Written {len(stats.keys())} graphs to {DIRPATH}")

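To make the on-disk layout the scan loop above expects more concrete, here is a hedged sketch that writes a few toy metrics.tsv files: one subdirectory per experiment run, each holding a tab-separated metrics.tsv with an epoch column. The directory name, metric names, and values are all invented for illustration and are not taken from the repository.

# Sketch only: build a toy input directory for crossval-stbl.py (names and values are made up).
import os
import pandas as pd

root = "crossval-stbl-demo"  # hypothetical directory to pass as the script's argument
for i, run in enumerate(["run-1", "run-2", "run-3"]):
	os.makedirs(os.path.join(root, run), exist_ok=True)
	pd.DataFrame({
		"epoch": [0, 1, 2],
		"loss": [1.20 + 0.01 * i, 0.95 + 0.01 * i, 0.80 + 0.01 * i],  # toy values, offset per run
	}).to_csv(os.path.join(root, run, "metrics.tsv"), sep="\t", index=False)

# Then, assuming the layout above: ./aimodel/scripts/crossval-stbl.py crossval-stbl-demo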
@@ -1,76 +0,0 @@
#!/usr/bin/env python3

import os
import sys

from loguru import logger
import pandas as pd

# This script analyses metrics.tsv files from a series of identical experiments and reports metrics on them.
# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0.


if len(sys.argv) <= 1:
	print("""
Usage:
	scripts/stbl-crossval.mjs {{path/to/directory}}

...in which the given directory contains a series of experiment root directories to include in the statistical analysis.

This script is not picky about the format of the data in metrics.tsv, so long as it's in the form:

epoch	metric_A	metric_B	…
0	val:float	val:float	…
1	val:float	val:float	…
2	val:float	val:float	…
⋮
""")
	sys.exit(0)

DIRPATH = sys.argv[1] # [0] == script path

files = 0
metrics = {}

for filepath in os.scandir(DIRPATH):
	tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t")
	
	# metrics.append(tbl)
	
	for column in tbl.columns:
		if column == "epoch":
			continue # Row index implicitly retains this
		
		if column not in metrics:
			metrics[column] = []
		
		metrics[column].append(tbl[column].values)
		# print(column, tbl[column])
	
	# print("DEBUG:metrics", tbl)
	files += 1

logger.info(f"Read {files} files into crossval-stbl{files} analysis")

stats = {}

for metric in metrics.keys():
	metrics[metric] = pd.DataFrame(metrics[metric]).transpose()
	
	if metric not in stats:
		stats[metric] = {}
	
	stats[metric]["aad"] = metrics[metric].max(axis=1) # mean/average absolute deviation
	stats[metric]["mad"] = metrics[metric].max(axis=1) # median absolute deviation
	stats[metric]["stddev"] = metrics[metric].std(axis=1)
	stats[metric]["mean"] = metrics[metric].mean(axis=1)
	stats[metric]["min"] = metrics[metric].min(axis=1)
	stats[metric]["max"] = metrics[metric].max(axis=1)
	stats[metric]["agg_min"] = stats[metric]["min"].min()
	stats[metric]["agg_max"] = stats[metric]["max"].max()
	stats[metric]["agg_stddev"] = metrics[metric].stack().std()
	stats[metric]["agg_mean"] = metrics[metric].stack().std()
	stats[metric]["agg_aad"] = metrics[metric].stack().max() # mean/average absolute deviation
	
	print(stats[metric])