mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
finish script to plot generic metrics
This commit is contained in:
parent
698bbe2ffb
commit
1bd59dc038
1 changed files with 30 additions and 9 deletions
39
aimodel/src/plot_metrics.py
Normal file → Executable file
39
aimodel/src/plot_metrics.py
Normal file → Executable file
|
@ -1,19 +1,40 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
def plot_metric(train, val, name, dir_output):
|
||||
plt.plot(train, label=f"train_{name}")
|
||||
plt.plot(val, label=f"val_{name}")
|
||||
plt.title(name)
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel(name)
|
||||
plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
||||
plt.close()
|
||||
def plot_metric(ax, train, val, name, dir_output):
|
||||
ax.plot(train, label=f"train_{name}")
|
||||
ax.plot(val, label=f"val_{name}")
|
||||
ax.set_title(name)
|
||||
ax.set_xlabel("epoch")
|
||||
ax.set_ylabel(name)
|
||||
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
||||
# plt.close()
|
||||
|
||||
FILEPATH_INPUT = os.environ["INPUT"]
|
||||
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
||||
|
||||
|
||||
df = pd.read_csv(FILEPATH_INPUT)
|
||||
df = pd.read_csv(FILEPATH_INPUT, sep="\t")
|
||||
|
||||
fig = plt.figure(figsize=(10,13))
|
||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())):
|
||||
train = df[colname]
|
||||
val = df[f"val_{colname}"]
|
||||
|
||||
colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \
|
||||
.replace("one_hot_mean_iou", "mean iou")
|
||||
|
||||
ax = fig.add_subplot(3, 2, i+1)
|
||||
|
||||
plot_metric(ax, train, val, name=colname_display, dir_output=DIRPATH_OUTPUT)
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
target=os.path.join(DIRPATH_OUTPUT, f"metrics.png")
|
||||
plt.savefig(target)
|
||||
|
||||
print(f">>> Saved to {target}")
|
Loading…
Reference in a new issue