mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
scripts/stbl-crossval: initial WIP draft
This commit is contained in:
parent
f8a1e1b594
commit
fda40b05c9
1 changed files with 76 additions and 0 deletions
76
aimodel/scripts/stbl-crossval.py
Executable file
76
aimodel/scripts/stbl-crossval.py
Executable file
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
import pandas as pd
|
||||
|
||||
# This script analyses metrics.tsv files from a series of identical experiments and reports metrics on them.
|
||||
# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0.
|
||||
|
||||
|
||||
if len(sys.argv) <= 1:
|
||||
print("""
|
||||
Usage:
|
||||
scripts/stbl-crossval.mjs {{path/to/directory}}
|
||||
|
||||
...in which the given directory contains a series of experiment root directories to include in the statistical analysis.
|
||||
|
||||
This script is not picky about the format of the data in metrics.tsv, so long as it's in the form:
|
||||
|
||||
epoch metric_A metric_B …
|
||||
0 val:float val:float …
|
||||
1 val:float val:float …
|
||||
2 val:float val:float …
|
||||
⋮
|
||||
""")
|
||||
sys.exit(0)
|
||||
|
||||
DIRPATH = sys.argv[1] # [0] == script path
|
||||
|
||||
files = 0
|
||||
metrics = {}
|
||||
|
||||
for filepath in os.scandir(DIRPATH):
|
||||
tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t")
|
||||
|
||||
# metrics.append(tbl)
|
||||
|
||||
for column in tbl.columns:
|
||||
if column == "epoch":
|
||||
continue # Row index implicitly retains this
|
||||
|
||||
if column not in metrics:
|
||||
metrics[column] = []
|
||||
|
||||
metrics[column].append(tbl[column].values)
|
||||
# print(column, tbl[column])
|
||||
|
||||
# print("DEBUG:metrics", tbl)
|
||||
files += 1
|
||||
|
||||
logger.info(f"Read {files} files into crossval-stbl{files} analysis")
|
||||
|
||||
stats = {}
|
||||
|
||||
for metric in metrics.keys():
|
||||
metrics[metric] = pd.DataFrame(metrics[metric]).transpose()
|
||||
|
||||
if metric not in stats:
|
||||
stats[metric] = {}
|
||||
|
||||
stats[metric]["aad"] = metrics[metric].max(axis=1) # mean/average absolute deviation
|
||||
stats[metric]["mad"] = metrics[metric].max(axis=1) # median absolute deviation
|
||||
stats[metric]["stddev"] = metrics[metric].std(axis=1)
|
||||
stats[metric]["mean"] = metrics[metric].mean(axis=1)
|
||||
stats[metric]["min"] = metrics[metric].min(axis=1)
|
||||
stats[metric]["max"] = metrics[metric].max(axis=1)
|
||||
stats[metric]["agg_min"] = stats[metric]["min"].min()
|
||||
stats[metric]["agg_max"] = stats[metric]["max"].max()
|
||||
stats[metric]["agg_stddev"] = metrics[metric].stack().std()
|
||||
stats[metric]["agg_mean"] = metrics[metric].stack().std()
|
||||
stats[metric]["agg_aad"] = metrics[metric].stack().max() # mean/average absolute deviation
|
||||
|
||||
print(stats[metric])
|
||||
|
Loading…
Reference in a new issue