scripts/stbl-crossval: initial WIP draft

This commit is contained in:
Starbeamrainbowlabs 2024-12-20 15:11:41 +00:00
parent f8a1e1b594
commit fda40b05c9
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -0,0 +1,76 @@
#!/usr/bin/env python3
import os
import sys
from loguru import logger
import pandas as pd
# This script analyses metrics.tsv files from a series of identical experiments and reports metrics on them.
# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0.
if len(sys.argv) <= 1:
print("""
Usage:
scripts/stbl-crossval.mjs {{path/to/directory}}
...in which the given directory contains a series of experiment root directories to include in the statistical analysis.
This script is not picky about the format of the data in metrics.tsv, so long as it's in the form:
epoch metric_A metric_B
0 val:float val:float
1 val:float val:float
2 val:float val:float
""")
sys.exit(0)
DIRPATH = sys.argv[1] # [0] == script path
files = 0
metrics = {}
for filepath in os.scandir(DIRPATH):
tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t")
# metrics.append(tbl)
for column in tbl.columns:
if column == "epoch":
continue # Row index implicitly retains this
if column not in metrics:
metrics[column] = []
metrics[column].append(tbl[column].values)
# print(column, tbl[column])
# print("DEBUG:metrics", tbl)
files += 1
logger.info(f"Read {files} files into crossval-stbl{files} analysis")
stats = {}
for metric in metrics.keys():
metrics[metric] = pd.DataFrame(metrics[metric]).transpose()
if metric not in stats:
stats[metric] = {}
stats[metric]["aad"] = metrics[metric].max(axis=1) # mean/average absolute deviation
stats[metric]["mad"] = metrics[metric].max(axis=1) # median absolute deviation
stats[metric]["stddev"] = metrics[metric].std(axis=1)
stats[metric]["mean"] = metrics[metric].mean(axis=1)
stats[metric]["min"] = metrics[metric].min(axis=1)
stats[metric]["max"] = metrics[metric].max(axis=1)
stats[metric]["agg_min"] = stats[metric]["min"].min()
stats[metric]["agg_max"] = stats[metric]["max"].max()
stats[metric]["agg_stddev"] = metrics[metric].stack().std()
stats[metric]["agg_mean"] = metrics[metric].stack().std()
stats[metric]["agg_aad"] = metrics[metric].stack().max() # mean/average absolute deviation
print(stats[metric])