if name == main

This commit is contained in:
Starbeamrainbowlabs 2023-03-23 18:09:52 +00:00
parent 81cad8e6b4
commit 4a4df380e3
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -14,11 +14,9 @@ def plot_metric(ax, train, val, name, dir_output):
# plt.savefig(os.path.join(dir_output, f"{name}.png")) # plt.savefig(os.path.join(dir_output, f"{name}.png"))
# plt.close() # plt.close()
FILEPATH_INPUT = os.environ["INPUT"]
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
def plot_metrics(filepath_input, dirpath_output):
df = pd.read_csv(FILEPATH_INPUT, sep="\t") df = pd.read_csv(filepath_input, sep="\t")
fig = plt.figure(figsize=(10,13)) fig = plt.figure(figsize=(10,13))
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())): for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), df.columns.values.tolist())):
@ -30,11 +28,19 @@ for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not co
ax = fig.add_subplot(3, 2, i+1) ax = fig.add_subplot(3, 2, i+1)
plot_metric(ax, train, val, name=colname_display, dir_output=DIRPATH_OUTPUT) plot_metric(ax, train, val, name=colname_display, dir_output=dirpath_output)
fig.tight_layout() fig.tight_layout()
target=os.path.join(DIRPATH_OUTPUT, f"metrics.png") target=os.path.join(dirpath_output, f"metrics.png")
plt.savefig(target) plt.savefig(target)
print(f">>> Saved to {target}") print(f">>> Saved to {target}")
if __name__ == "__main__":
FILEPATH_INPUT = os.environ["INPUT"]
DIRPATH_OUTPUT = os.environ["OUTPUT"] if "OUTPUT" in os.environ else os.getcwd()
plot_metrics(FILEPATH_INPUT, DIRPATH_OUTPUT)