From 5a412ddc26a8a36c9687ddc68189d61bc9fc0b13 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 20 Dec 2024 18:37:18 +0000 Subject: [PATCH] scripts/crossval-stbl: finish off the script TODO switch ou median absolute distance for something else when Nina replies --- aimodel/scripts/crossval-stbl.py | 119 +++++++++++++++++++++++++++++++ aimodel/scripts/stbl-crossval.py | 76 -------------------- 2 files changed, 119 insertions(+), 76 deletions(-) create mode 100755 aimodel/scripts/crossval-stbl.py delete mode 100755 aimodel/scripts/stbl-crossval.py diff --git a/aimodel/scripts/crossval-stbl.py b/aimodel/scripts/crossval-stbl.py new file mode 100755 index 0000000..672bffc --- /dev/null +++ b/aimodel/scripts/crossval-stbl.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +import os +import sys + +from loguru import logger +import pandas as pd +import scipy +from matplotlib import pyplot as plt + +# This script analyses metrics.tsv files from a series of identical experiments and reports metrics on them. +# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0. + + +if len(sys.argv) <= 1: + print(""" +Usage: + scripts/stbl-crossval.mjs {{path/to/directory}} + +...in which the given directory contains a series of experiment root directories to include in the statistical analysis. + +This script is not picky about the format of the data in metrics.tsv, so long as it's in the form: + +epoch metric_A metric_B … +0 val:float val:float … +1 val:float val:float … +2 val:float val:float … +⋮ +""") + sys.exit(0) + +DIRPATH = sys.argv[1] # [0] == script path + +files = 0 +metrics = {} +epochs = None + +for filepath in os.scandir(DIRPATH): + if not os.path.isdir(filepath): + continue + tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t") + + # metrics.append(tbl) + + for column in tbl.columns: + if column == "epoch": + if epochs is None: + epochs = tbl[column] + continue # Row index implicitly retains this + + if column not in metrics: + metrics[column] = [] + + metrics[column].append(tbl[column].values) + # print(column, tbl[column]) + + # print("DEBUG:metrics", tbl) + files += 1 + +logger.info(f"Read {files} files into crossval-stbl{files} analysis") + +stats = {} + +for metric in metrics.keys(): + metrics[metric] = pd.DataFrame(metrics[metric]).transpose() + + if metric not in stats: + stats[metric] = {} + + stats[metric]["mad"] = scipy.stats.median_abs_deviation( + metrics[metric], axis=1 + ) # median absolute deviation + stats[metric]["stddev"] = metrics[metric].std(axis=1) + stats[metric]["mean"] = metrics[metric].mean(axis=1) + stats[metric]["min"] = metrics[metric].min(axis=1) + stats[metric]["max"] = metrics[metric].max(axis=1) + stats[metric]["agg_min"] = stats[metric]["min"].min() + stats[metric]["agg_max"] = stats[metric]["max"].max() + stats[metric]["agg_stddev"] = metrics[metric].stack().std() + stats[metric]["agg_mean"] = metrics[metric].stack().std() + stats[metric]["agg_mad"] = scipy.stats.median_abs_deviation( + metrics[metric].stack() + ) # median absolute deviation + + # print(stats[metric]) + + plt.figure(figsize=(12, 8)) + plt.ylim(min(0, stats[metric]["agg_min"]), max(1, stats[metric]["agg_max"])) + plt.grid(visible=True, which="major", axis="y", linewidth=2) + plt.grid(visible=True, which="minor", axis="y", linewidth=1) + plt.minorticks_on() + plt.fill_between( + epochs, + stats[metric]["min"], + stats[metric]["max"], + alpha=0.2, + facecolor="#B7DE28", + edgecolor="#FDE724", + linestyle="dotted", + linewidth=1, + ) + plt.fill_between( + epochs, + stats[metric]["mean"] - stats[metric]["mad"], + stats[metric]["mean"] + stats[metric]["mad"], + alpha=0.5, + facecolor="#228A8D", + edgecolor="#3CBB74", + linestyle="dashed", + linewidth=1, + ) + plt.plot(epochs, stats[metric]["mean"], color="#450C54") + plt.title(f"{metric} // crossval-stbl{files}") + plt.xlabel("epoch") + plt.ylabel(metric) + plt.savefig(os.path.join(DIRPATH, f"crossval-stbl{files}_{metric}.png")) + plt.close() + +logger.success(f"Written {len(stats.keys())} graphs to {DIRPATH}") diff --git a/aimodel/scripts/stbl-crossval.py b/aimodel/scripts/stbl-crossval.py deleted file mode 100755 index 6536ddc..0000000 --- a/aimodel/scripts/stbl-crossval.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys - -from loguru import logger -import pandas as pd - -# This script analyses metrics.tsv files from a series of identical experiments and reports metrics on them. -# This is sometimes known as cross-validation, but we usually use the model series code crossval-stblX, where X is an integer >0. - - -if len(sys.argv) <= 1: - print(""" -Usage: - scripts/stbl-crossval.mjs {{path/to/directory}} - -...in which the given directory contains a series of experiment root directories to include in the statistical analysis. - -This script is not picky about the format of the data in metrics.tsv, so long as it's in the form: - -epoch metric_A metric_B … -0 val:float val:float … -1 val:float val:float … -2 val:float val:float … -⋮ -""") - sys.exit(0) - -DIRPATH = sys.argv[1] # [0] == script path - -files = 0 -metrics = {} - -for filepath in os.scandir(DIRPATH): - tbl = pd.read_csv(os.path.join(filepath, "metrics.tsv"), sep="\t") - - # metrics.append(tbl) - - for column in tbl.columns: - if column == "epoch": - continue # Row index implicitly retains this - - if column not in metrics: - metrics[column] = [] - - metrics[column].append(tbl[column].values) - # print(column, tbl[column]) - - # print("DEBUG:metrics", tbl) - files += 1 - -logger.info(f"Read {files} files into crossval-stbl{files} analysis") - -stats = {} - -for metric in metrics.keys(): - metrics[metric] = pd.DataFrame(metrics[metric]).transpose() - - if metric not in stats: - stats[metric] = {} - - stats[metric]["aad"] = metrics[metric].max(axis=1) # mean/average absolute deviation - stats[metric]["mad"] = metrics[metric].max(axis=1) # median absolute deviation - stats[metric]["stddev"] = metrics[metric].std(axis=1) - stats[metric]["mean"] = metrics[metric].mean(axis=1) - stats[metric]["min"] = metrics[metric].min(axis=1) - stats[metric]["max"] = metrics[metric].max(axis=1) - stats[metric]["agg_min"] = stats[metric]["min"].min() - stats[metric]["agg_max"] = stats[metric]["max"].max() - stats[metric]["agg_stddev"] = metrics[metric].stack().std() - stats[metric]["agg_mean"] = metrics[metric].stack().std() - stats[metric]["agg_aad"] = metrics[metric].stack().max() # mean/average absolute deviation - - print(stats[metric]) -