diff --git a/aimodel/src/plot_metrics.py b/aimodel/src/plot_metrics.py index a43b65e..f668f0f 100755 --- a/aimodel/src/plot_metrics.py +++ b/aimodel/src/plot_metrics.py @@ -14,27 +14,33 @@ def plot_metric(ax, train, val, name, dir_output): # plt.savefig(os.path.join(dir_output, f"{name}.png")) # plt.close() -FILEPATH_INPUT = os.environ["INPUT"] -DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd() + +def plot_metrics(filepath_input, dirpath_output): + df = pd.read_csv(filepath_input, sep="\t") + + fig = plt.figure(figsize=(10,13)) + for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())): + train = df[colname] + val = df[f"val_{colname}"] + + colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \ + .replace("one_hot_mean_iou", "mean iou") + + ax = fig.add_subplot(3, 2, i+1) + + plot_metric(ax, train, val, name=colname_display, dir_output=dirpath_output) + + fig.tight_layout() + + target=os.path.join(dirpath_output, f"metrics.png") + plt.savefig(target) + + print(f">>> Saved to {target}") -df = pd.read_csv(FILEPATH_INPUT, sep="\t") - -fig = plt.figure(figsize=(10,13)) -for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())): - train = df[colname] - val = df[f"val_{colname}"] +if __name__ == "__main__": + FILEPATH_INPUT = os.environ["INPUT"] + DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd() - colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \ - .replace("one_hot_mean_iou", "mean iou") + plot_metrics(FILEPATH_INPUT, DIRPATH_OUTPUT) - ax = fig.add_subplot(3, 2, i+1) - - plot_metric(ax, train, val, name=colname_display, dir_output=DIRPATH_OUTPUT) - -fig.tight_layout() - -target=os.path.join(DIRPATH_OUTPUT, f"metrics.png") -plt.savefig(target) - -print(f">>> Saved to {target}")