mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-16 14:43:01 +00:00
plot_metrics_multi: add option to plot train/val separately
This commit is contained in:
parent
b5b26d980b
commit
e44b5533b1
1 changed files with 27 additions and 6 deletions
|
@ -33,15 +33,29 @@ def plot_metric(ax, train_list, val_list, metric_name, model_names, dir_output):
|
||||||
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
||||||
# plt.close()
|
# plt.close()
|
||||||
|
|
||||||
|
def plot_metric_single(ax, data_list, metric_name, model_names, dir_output):
|
||||||
|
i = 0
|
||||||
|
for item in data_list:
|
||||||
|
ax.plot(item, label=model_names[i], linewidth=1)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
ax.set_title(metric_name)
|
||||||
|
ax.set_xlabel("epoch")
|
||||||
|
ax.set_ylabel(metric_name)
|
||||||
|
# plt.savefig(os.path.join(dir_output, f"{name}.png"))
|
||||||
|
# plt.close()
|
||||||
|
|
||||||
|
|
||||||
def make_dfs(filepaths_input):
|
def make_dfs(filepaths_input):
|
||||||
dfs = []
|
dfs = []
|
||||||
for filepath_input in filepaths_input:
|
for filepath_input in filepaths_input:
|
||||||
print("DEBUG filepath_input", filepath_input)
|
print("DEBUG filepath_input", filepath_input)
|
||||||
dfs = pd.read_csv(filepath_input, sep="\t")
|
dfs = pd.read_csv(filepath_input, sep="\t")
|
||||||
|
|
||||||
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
|
||||||
|
def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1, train_val_separate=False):
|
||||||
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
dfs = [ pd.read_csv(filepath_input, sep="\t") for filepath_input in filepaths_input ]
|
||||||
matplotlib.rcParams.update({'font.size': 15*resolution})
|
matplotlib.rcParams.update({'font.size': 15*resolution*(0.5 if train_val_separate else 1)})
|
||||||
fig = plt.figure(figsize=(10*resolution, 14*resolution))
|
fig = plt.figure(figsize=(10*resolution, 14*resolution))
|
||||||
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
for i, colname in enumerate(filter(lambda colname: colname != "epoch" and not colname.startswith("val_"), dfs[0].columns.values.tolist())):
|
||||||
train = [ df[colname] for df in dfs ]
|
train = [ df[colname] for df in dfs ]
|
||||||
|
@ -50,11 +64,17 @@ def plot_metrics(filepaths_input, model_names, dirpath_output, resolution=1):
|
||||||
colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \
|
colname_display = colname.replace("metric_dice_coefficient", "dice coefficient") \
|
||||||
.replace("one_hot_mean_iou", "mean iou")
|
.replace("one_hot_mean_iou", "mean iou")
|
||||||
|
|
||||||
ax = fig.add_subplot(3, 2, i+1)
|
if train_val_separate:
|
||||||
|
ax = fig.add_subplot(3*2, 2*2, (i+1)*2 - 1)
|
||||||
|
plot_metric_single(ax, train, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
||||||
|
ax = fig.add_subplot(3*2, 2*2, (i+1)*2)
|
||||||
|
plot_metric_single(ax, val, metric_name=f"val_{colname_display}", model_names=model_names, dir_output=dirpath_output)
|
||||||
|
else:
|
||||||
|
ax = fig.add_subplot(3 * 2, 2, i+1)
|
||||||
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
plot_metric(ax, train, val, metric_name=colname_display, model_names=model_names, dir_output=dirpath_output)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Ref https://stackoverflow.com/a/57484812/1460422
|
# Ref https://stackoverflow.com/a/57484812/1460422
|
||||||
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
|
# lines_labels = [ ax.get_legend_handles_labels() for ax in fig.axes ]
|
||||||
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
|
lines_labels = [ fig.axes[0].get_legend_handles_labels() ]
|
||||||
|
@ -91,7 +111,7 @@ Usage:
|
||||||
echo -e "filepathA\\nfilepathB..." | [OUTPUT="path/to/output_dir"] [REGEX_NAME=''] path/to/plot_metrics_multi.py
|
echo -e "filepathA\\nfilepathB..." | [OUTPUT="path/to/output_dir"] [REGEX_NAME=''] path/to/plot_metrics_multi.py
|
||||||
""")
|
""")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
TRAIN_VAL_SEPARATE = True if "TRAIN_VAL_SEPARATE" in os.environ else False
|
||||||
REGEX_NAME = os.environ["REGEX_NAME"] if "REGEX_NAME" in os.environ else None
|
REGEX_NAME = os.environ["REGEX_NAME"] if "REGEX_NAME" in os.environ else None
|
||||||
if REGEX_NAME is None and len(sys.argv) >= 1:
|
if REGEX_NAME is None and len(sys.argv) >= 1:
|
||||||
REGEX_NAME = sys.argv[1]
|
REGEX_NAME = sys.argv[1]
|
||||||
|
@ -126,6 +146,7 @@ Usage:
|
||||||
FILEPATHS_INPUT,
|
FILEPATHS_INPUT,
|
||||||
MODEL_NAMES,
|
MODEL_NAMES,
|
||||||
DIRPATH_OUTPUT,
|
DIRPATH_OUTPUT,
|
||||||
resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1
|
resolution=float(os.environ["RESOLUTION"]) if "RESOLUTION" in os.environ else 1,
|
||||||
|
train_val_separate=TRAIN_VAL_SEPARATE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue