mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
if name == main
This commit is contained in:
parent
81cad8e6b4
commit
4a4df380e3
1 changed files with 26 additions and 20 deletions
|
@ -14,27 +14,33 @@ def plot_metric(ax, train, val, name, dir_output):
|
||||||
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
||||||
# plt.close()
|
# plt.close()
|
||||||
|
|
||||||
FILEPATH_INPUT = os.environ["INPUT"]
|
|
||||||
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
def plot_metrics(filepath_input, dirpath_output):
|
||||||
|
df = pd.read_csv(filepath_input, sep="\t")
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(10,13))
|
||||||
|
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())):
|
||||||
|
train = df[colname]
|
||||||
|
val = df[f"val_{colname}"]
|
||||||
|
|
||||||
|
colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \
|
||||||
|
.replace("one_hot_mean_iou", "mean iou")
|
||||||
|
|
||||||
|
ax = fig.add_subplot(3, 2, i+1)
|
||||||
|
|
||||||
|
plot_metric(ax, train, val, name=colname_display, dir_output=dirpath_output)
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
target=os.path.join(dirpath_output, f"metrics.png")
|
||||||
|
plt.savefig(target)
|
||||||
|
|
||||||
|
print(f">>> Saved to {target}")
|
||||||
|
|
||||||
|
|
||||||
df = pd.read_csv(FILEPATH_INPUT, sep="\t")
|
if __name__ == "__main__":
|
||||||
|
FILEPATH_INPUT = os.environ["INPUT"]
|
||||||
|
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
|
||||||
|
|
||||||
fig = plt.figure(figsize=(10,13))
|
plot_metrics(FILEPATH_INPUT, DIRPATH_OUTPUT)
|
||||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())):
|
|
||||||
train = df[colname]
|
|
||||||
val = df[f"val_{colname}"]
|
|
||||||
|
|
||||||
colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \
|
|
||||||
.replace("one_hot_mean_iou", "mean iou")
|
|
||||||
|
|
||||||
ax = fig.add_subplot(3, 2, i+1)
|
|
||||||
|
|
||||||
plot_metric(ax, train, val, name=colname_display, dir_output=DIRPATH_OUTPUT)
|
|
||||||
|
|
||||||
fig.tight_layout()
|
|
||||||
|
|
||||||
target=os.path.join(DIRPATH_OUTPUT, f"metrics.png")
|
|
||||||
plt.savefig(target)
|
|
||||||
|
|
||||||
print(f">>> Saved to {target}")
|
|
||||||
|
|
Loading…
Reference in a new issue