summary logger → summarywriter

This commit is contained in:
Starbeamrainbowlabs 2022-09-02 17:28:00 +01:00
parent 9f7f4af784
commit b9d01ddadc
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 6 additions and 4 deletions

View file

@ -28,6 +28,7 @@ class RainfallWaterContraster(object):
if filepath_checkpoint == None:
writefile(self.filepath_summary, "") # Empty the file ahead of time
self.model = self.make_model()
if self.dir_output == None:
raise Exception("Error: dir_output was not specified, and since no checkpoint was loaded training mode is activated.")
@ -36,7 +37,7 @@ class RainfallWaterContraster(object):
self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
summarywriter(self.model, self.filepath_summary)
summarywriter(self.model, self.filepath_summary, append=True)
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
else:
self.model = self.load_model(filepath_checkpoint)

View file

@ -2,12 +2,12 @@ import tensorflow as tf
from loguru import logger
# from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from ..helpers.summarywriter import summarylogger
from ..helpers.summarywriter import summarywriter
from .convnext import make_convnext
class LayerContrastiveEncoder(tf.keras.layers.Layer):
def __init__(self, input_width, input_height, channels, feature_dim=2048, **kwargs):
def __init__(self, input_width, input_height, channels, summary_file=None, feature_dim=2048, **kwargs):
"""Creates a new contrastive learning encoder layer.
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
@ -42,7 +42,8 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
# """
# self.embedding = tf.keras.layers.Dense(self.param_feature_dim)
summarylogger(self.encoder)
if summary_file:
summarywriter(self.encoder, append=True)
def get_config(self):
config = super(LayerContrastiveEncoder, self).get_config()