Mirror of https://github.com/sbrl/research-rainfallradar (synced 2024-11-29 20:33:00 +00:00)
Add metrics every 64 batches
This is important: when an epoch contains many batches, it can be difficult to tell what is happening inside it.
commit 5f8d6dc6ea (parent cf872ef739)
2 changed files with 36 additions and 0 deletions
aimodel/src/lib/ai/components/CallbackNBatchCsv.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import tensorflow as tf
+
+from lib.io.handle_open import handle_open
+
+
+class CallbackNBatchCsv(tf.keras.callbacks.Callback):
+	def __init__(self, filepath, n_batches=1, separator="\t", **kwargs) -> None:
+		super().__init__(**kwargs)
+		
+		self.n_batches = n_batches
+		self.separator = separator
+		
+		self.handle = handle_open(filepath)
+		
+		self.batches_seen = 0
+		self.keys = None
+	
+	def write_header(self, logs): # logs = dict of metric names to values
+		# dict.keys() returns a view with no .sort() method; take a sorted copy instead
+		self.keys = sorted(logs.keys())
+		self.handle.write(self.separator.join(self.keys) + "\n") # use the configured separator for the header row too
+	
+	def on_batch_end(self, batch, logs=None): # logs = dict of metric names to values
+		if self.batches_seen == 0:
+			self.write_header(logs)
+		
+		if self.batches_seen % self.n_batches == 0:
+			self.handle.write(self.separator.join([str(logs[key]) for key in self.keys]) + "\n")
+		
+		self.batches_seen += 1
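For context, a minimal usage sketch of the new callback on its own, outside make_callbacks. This is not part of the commit: the toy model, the random training data, and the import path (which assumes aimodel/src is on the Python path) are illustrative assumptions.

# Hypothetical usage sketch -- not part of this commit.
# Assumes aimodel/src is on the Python path so the repo's lib package resolves.
import numpy as np
import tensorflow as tf

from lib.ai.components.CallbackNBatchCsv import CallbackNBatchCsv

# A toy model purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(4096, 8).astype("float32")
y = np.random.rand(4096, 1).astype("float32")

# With batch_size=32, a metrics row is appended every 64 batches,
# i.e. every 2048 samples, rather than only once per epoch.
model.fit(
	x, y,
	batch_size=32,
	epochs=2,
	callbacks=[CallbackNBatchCsv(filepath="metrics_batch64.tsv", n_batches=64)],
)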
@@ -3,10 +3,12 @@ import os
 import tensorflow as tf
 
 from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
+from ..components.CallbackNBatchCsv import CallbackNBatchCsv
 
 def make_callbacks(dirpath, model_predict):
 	dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
 	filepath_metrics = os.path.join(dirpath, "metrics.tsv")
+	filepath_metrics_batch = os.path.join(dirpath, "metrics_batch64.tsv")
 	
 	if not os.path.exists(dirpath_checkpoints):
 		os.mkdir(dirpath_checkpoints)
@@ -24,5 +26,9 @@ def make_callbacks(dirpath, model_predict):
 		filename=filepath_metrics,
 		separator="\t"
 	),
+	CallbackNBatchCsv(
+		filepath=filepath_metrics_batch,
+		n_batches=64
+	),
 	tf.keras.callbacks.ProgbarLogger(count_mode="steps") # batches
 ]
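The resulting metrics_batch64.tsv is a plain tab-separated file: one header row of sorted metric names, then one row per 64 training batches. A hedged inspection sketch, assuming pandas is available and the file is read from the run directory after training:

# Hypothetical inspection snippet -- not part of this commit.
import pandas as pd

# One row of metrics per 64 training batches, tab-separated.
df = pd.read_csv("metrics_batch64.tsv", sep="\t")
print(df.describe())      # summary statistics per logged metric
print(df["loss"].head())  # "loss" is always present in Keras training logs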