mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
Add metrics every 64 batches
this is important, because with large batches it can be difficult to tell what's happening inside each epoch.
This commit is contained in:
parent
cf872ef739
commit
5f8d6dc6ea
2 changed files with 36 additions and 0 deletions
30
aimodel/src/lib/ai/components/CallbackNBatchCsv.py
Normal file
30
aimodel/src/lib/ai/components/CallbackNBatchCsv.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import tensorflow as tf
|
||||
|
||||
from lib.io.handle_open import handle_open
|
||||
|
||||
class CallbackNBatchCsv(tf.keras.callbacks.Callback):
|
||||
def __init__(self, filepath, n_batches=1, separator="\t", **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.n_batches = n_batches
|
||||
self.separator = separator
|
||||
|
||||
self.handle = handle_open(filepath)
|
||||
|
||||
|
||||
self.batches_seen = 0
|
||||
self.keys = None
|
||||
|
||||
def write_header(self, logs): # logs = metrics
|
||||
self.keys = logs.keys()
|
||||
self.keys.sort()
|
||||
self.handle.write("\t".join(self.keys)+"\n")
|
||||
|
||||
def on_batch_end(self, batch, logs=None): # logs = metrics
|
||||
if self.batches_seen == 0:
|
||||
self.write_header(logs)
|
||||
|
||||
if self.batches_seen % self.n_batches == 0:
|
||||
self.handle.write(self.separator.join([str(logs[key]) for key in self.keys]) + "\n")
|
||||
|
||||
self.batches_seen += 1
|
|
@ -3,10 +3,12 @@ import os
|
|||
import tensorflow as tf
|
||||
|
||||
from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||
from ..components.CallbackNBatchCsv import CallbackNBatchCsv
|
||||
|
||||
def make_callbacks(dirpath, model_predict):
|
||||
dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
|
||||
filepath_metrics = os.path.join(dirpath, "metrics.tsv")
|
||||
filepath_metrics_batch = os.path.join(dirpath, "metrics_batch64.tsv")
|
||||
|
||||
if not os.path.exists(dirpath_checkpoints):
|
||||
os.mkdir(dirpath_checkpoints)
|
||||
|
@ -24,5 +26,9 @@ def make_callbacks(dirpath, model_predict):
|
|||
filename=filepath_metrics,
|
||||
separator="\t"
|
||||
),
|
||||
CallbackNBatchCsv(
|
||||
filepath=filepath_metrics_batch,
|
||||
n_batches=64
|
||||
),
|
||||
tf.keras.callbacks.ProgbarLogger(count_mode="steps") # batches
|
||||
]
|
Loading…
Reference in a new issue