Add metrics every 64 batches

this is important, because with large batches it can be difficult to tell what's happening inside each epoch.
This commit is contained in:
Starbeamrainbowlabs 2022-10-31 19:26:10 +00:00
parent cf872ef739
commit 5f8d6dc6ea
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 36 additions and 0 deletions

View file

@ -0,0 +1,30 @@
import tensorflow as tf
from lib.io.handle_open import handle_open
class CallbackNBatchCsv(tf.keras.callbacks.Callback):
def __init__(self, filepath, n_batches=1, separator="\t", **kwargs) -> None:
super().__init__(**kwargs)
self.n_batches = n_batches
self.separator = separator
self.handle = handle_open(filepath)
self.batches_seen = 0
self.keys = None
def write_header(self, logs): # logs = metrics
self.keys = logs.keys()
self.keys.sort()
self.handle.write("\t".join(self.keys)+"\n")
def on_batch_end(self, batch, logs=None): # logs = metrics
if self.batches_seen == 0:
self.write_header(logs)
if self.batches_seen % self.n_batches == 0:
self.handle.write(self.separator.join([str(logs[key]) for key in self.keys]) + "\n")
self.batches_seen += 1

View file

@ -3,10 +3,12 @@ import os
import tensorflow as tf import tensorflow as tf
from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
from ..components.CallbackNBatchCsv import CallbackNBatchCsv
def make_callbacks(dirpath, model_predict): def make_callbacks(dirpath, model_predict):
dirpath_checkpoints = os.path.join(dirpath, "checkpoints") dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
filepath_metrics = os.path.join(dirpath, "metrics.tsv") filepath_metrics = os.path.join(dirpath, "metrics.tsv")
filepath_metrics_batch = os.path.join(dirpath, "metrics_batch64.tsv")
if not os.path.exists(dirpath_checkpoints): if not os.path.exists(dirpath_checkpoints):
os.mkdir(dirpath_checkpoints) os.mkdir(dirpath_checkpoints)
@ -24,5 +26,9 @@ def make_callbacks(dirpath, model_predict):
filename=filepath_metrics, filename=filepath_metrics,
separator="\t" separator="\t"
), ),
CallbackNBatchCsv(
filepath=filepath_metrics_batch,
n_batches=64
),
tf.keras.callbacks.ProgbarLogger(count_mode="steps") # batches tf.keras.callbacks.ProgbarLogger(count_mode="steps") # batches
] ]