#!/usr/bin/env python3
"""Report cross-run statistics for a series of identical experiments.

Every immediate subdirectory of the directory given on the command line is
treated as one experiment root containing a ``metrics.tsv`` file. For each
metric column (every column except ``epoch``) the values of all runs are
lined up epoch by epoch and summary statistics are printed.

This is sometimes known as cross-validation, but we usually use the model
series code crossval-stblX, where X is an integer >0.
"""
# Provenance: reconstructed from patch fda40b05c9328e127482d54e88648e49e84b1664
# by Starbeamrainbowlabs, Fri, 20 Dec 2024 15:11:41 +0000,
# "scripts/stbl-crossval: initial WIP draft",
# which creates aimodel/scripts/stbl-crossval.py (mode 100755).

import os
import sys

import pandas as pd

# Printed verbatim when no directory argument is supplied.
USAGE = """
Usage:
    scripts/stbl-crossval.py {{path/to/directory}}

...in which the given directory contains a series of experiment root directories to include in the statistical analysis.

This script is not picky about the format of the data in metrics.tsv, so long as it's in the form:

epoch   metric_A    metric_B    …
0       val:float   val:float   …
1       val:float   val:float   …
2       val:float   val:float   …
⋮
"""


def load_metrics(dirpath):
    """Read ``metrics.tsv`` from every experiment directory in *dirpath*.

    Args:
        dirpath: Directory whose immediate subdirectories are experiment
            roots, each holding a tab-separated ``metrics.tsv``.

    Returns:
        A ``(metrics, files)`` tuple. ``metrics`` maps each column name
        (except ``epoch``) to a list with one array of values per run;
        ``files`` is the number of ``metrics.tsv`` files read.

    Raises:
        FileNotFoundError: If a subdirectory has no ``metrics.tsv``. This is
            deliberately not swallowed so an incomplete run is noticed.
    """
    metrics = {}
    files = 0

    # Close the scandir handle promptly, ignore stray non-directory entries,
    # and sort by name so the column order is reproducible between runs.
    with os.scandir(dirpath) as entries:
        run_dirs = sorted(
            (entry for entry in entries if entry.is_dir()),
            key=lambda entry: entry.name,
        )

    for entry in run_dirs:
        tbl = pd.read_csv(os.path.join(entry.path, "metrics.tsv"), sep="\t")

        for column in tbl.columns:
            if column == "epoch":
                continue  # Row index implicitly retains this
            metrics.setdefault(column, []).append(tbl[column].values)

        files += 1

    return metrics, files


def compute_stats(runs):
    """Compute per-epoch and aggregate statistics for a single metric.

    Args:
        runs: DataFrame with one row per epoch and one column per run.

    Returns:
        A dict with per-epoch Series under ``aad`` (mean absolute deviation),
        ``mad`` (median absolute deviation), ``stddev``, ``mean``, ``min``
        and ``max``, plus scalars over every value of every run under
        ``agg_min``, ``agg_max``, ``agg_stddev``, ``agg_mean`` and
        ``agg_aad``. Missing values (e.g. runs of unequal length) are
        skipped, as pandas reductions do by default.
    """
    epoch_mean = runs.mean(axis=1)
    epoch_median = runs.median(axis=1)
    epoch_min = runs.min(axis=1)
    epoch_max = runs.max(axis=1)

    # Flatten all runs/epochs into one series for the aggregate figures.
    flat = pd.Series(runs.to_numpy().ravel())
    flat_mean = flat.mean()

    return {
        # The WIP draft used max() as a placeholder for all three absolute
        # deviations and std() for agg_mean; these are the real definitions.
        "aad": runs.sub(epoch_mean, axis=0).abs().mean(axis=1),
        "mad": runs.sub(epoch_median, axis=0).abs().median(axis=1),
        "stddev": runs.std(axis=1),
        "mean": epoch_mean,
        "min": epoch_min,
        "max": epoch_max,
        "agg_min": epoch_min.min(),
        "agg_max": epoch_max.max(),
        "agg_stddev": flat.std(),
        "agg_mean": flat_mean,
        "agg_aad": (flat - flat_mean).abs().mean(),
    }


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument vector including the script path at index 0;
            defaults to ``sys.argv``.
    """
    args = sys.argv if argv is None else argv

    if len(args) <= 1:
        print(USAGE)
        sys.exit(0)

    # Imported here so the helpers above stay importable (e.g. for reuse or
    # testing) on systems where the logging dependency is not installed.
    from loguru import logger

    dirpath = args[1]  # [0] == script path

    metrics, files = load_metrics(dirpath)
    logger.info(f"Read {files} files into crossval-stbl{files} analysis")

    stats = {}
    for metric, columns in metrics.items():
        # One array per run -> transpose so rows are epochs, columns are runs.
        runs = pd.DataFrame(columns).transpose()
        stats[metric] = compute_stats(runs)
        print(stats[metric])


if __name__ == "__main__":
    main()