ai: implement saving only the rainfall encoder

This commit is contained in:
Starbeamrainbowlabs 2022-09-06 19:48:46 +01:00
parent 4c4358c3e5
commit 22620a1854
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
5 changed files with 49 additions and 18 deletions

View file

@ -36,12 +36,12 @@ class RainfallWaterContraster(object):
self.filepath_summary = os.path.join(self.dir_output, "summary.txt") self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
writefile(self.filepath_summary, "") # Empty the file ahead of time writefile(self.filepath_summary, "") # Empty the file ahead of time
self.model = self.make_model() self.make_model()
summarywriter(self.model, self.filepath_summary, append=True) summarywriter(self.model, self.filepath_summary, append=True)
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config())) writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
else: else:
self.model = self.load_model(filepath_checkpoint) self.load_model(filepath_checkpoint)
def get_config(self): def get_config(self):
return { return {
@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
def make_model(self): def make_model(self):
return model_rainfallwater_contrastive( self.model, self.model_predict = model_rainfallwater_contrastive(
batch_size=self.batch_size, batch_size=self.batch_size,
summary_file=self.filepath_summary, summary_file=self.filepath_summary,
**self.kwargs **self.kwargs
@ -72,7 +72,7 @@ class RainfallWaterContraster(object):
filepath_checkpoint (string): The filepath to load the saved model from. filepath_checkpoint (string): The filepath to load the saved model from.
""" """
return tf.keras.models.load_model(filepath_checkpoint, custom_objects={ self.model_predict = tf.keras.models.load_model(filepath_checkpoint, custom_objects={
"LayerContrastiveEncoder": LayerContrastiveEncoder, "LayerContrastiveEncoder": LayerContrastiveEncoder,
"LayerCheeseMultipleOut": LayerCheeseMultipleOut "LayerCheeseMultipleOut": LayerCheeseMultipleOut
}) })
@ -84,7 +84,8 @@ class RainfallWaterContraster(object):
dataset_train, dataset_train,
validation_data=dataset_validate, validation_data=dataset_validate,
epochs=self.epochs, epochs=self.epochs,
callbacks=make_callbacks(self.dir_output) callbacks=make_callbacks(self.dir_output, self.model_predict),
steps_per_epoch=10 # For testing
) )
def embed(self, dataset): def embed(self, dataset):
@ -92,7 +93,7 @@ class RainfallWaterContraster(object):
i_batch = -1 i_batch = -1
for batch in dataset: for batch in dataset:
i_batch += 1 i_batch += 1
result_batch = self.model(batch[0]) result_batch = self.model(batch[0]) # ((rainfall, water), dummy_label)
rainfall, water = tf.unstack(result_batch, axis=-2) rainfall, water = tf.unstack(result_batch, axis=-2)
rainfall = tf.unstack(rainfall, axis=0) rainfall = tf.unstack(rainfall, axis=0)
@ -100,4 +101,11 @@ class RainfallWaterContraster(object):
result.extend(zip(rainfall, water)) result.extend(zip(rainfall, water))
return result
def embed_rainfall(self, dataset):
result = []
for batch in dataset:
result_batch = self.model_predict(batch)
result.extend(tf.unstack(result_batch, axis=0))
return result return result

View file

@ -0,0 +1,15 @@
from loguru import logger
import tensorflow as tf
class CallbackCustomModelCheckpoint(tf.keras.callbacks.Callback):
def __init__(self, model_to_checkpoint, **kwargs) -> None:
super().__init__()
self.model_to_checkpoint = model_to_checkpoint
self.checkpointer = tf.keras.callbacks.ModelCheckpoint(**kwargs)
def on_epoch_end(self, epoch, logs=None):
logger.info("Saving checkpoint")
self.checkpointer.set_model(self.model_to_checkpoint)
self.checkpointer.on_epoch_end(epoch=epoch, logs=logs)
logger.info("Checkpoint saved successfully")

View file

@ -8,10 +8,10 @@ class LossContrastive(tf.keras.losses.Loss):
def call(self, y_true, y_pred): def call(self, y_true, y_pred):
rainfall, water = tf.unstack(y_pred, axis=-2) rainfall, water = tf.unstack(y_pred, axis=-2)
print("LOSS:call y_true", y_true.shape) # print("LOSS:call y_true", y_true.shape)
print("LOSS:call y_pred", y_pred.shape) # print("LOSS:call y_pred", y_pred.shape)
print("BEFORE_RESHAPE rainfall", rainfall) # print("BEFORE_RESHAPE rainfall", rainfall)
print("BEFORE_RESHAPE water", water) # print("BEFORE_RESHAPE water", water)
# # Ensure the shapes are defined # # Ensure the shapes are defined
# rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]]) # rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]])
@ -20,7 +20,7 @@ class LossContrastive(tf.keras.losses.Loss):
logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100) logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
print("LOGITS", logits) # print("LOGITS", logits)
labels = tf.eye(self.batch_size, dtype=tf.int32) labels = tf.eye(self.batch_size, dtype=tf.int32)
loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0) loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
@ -29,9 +29,9 @@ class LossContrastive(tf.keras.losses.Loss):
loss = (loss_rainfall + loss_water) / 2 loss = (loss_rainfall + loss_water) / 2
# cosine_similarity results in tensor of range -1 - 1, but tf.sparse.eye has range 0 - 1 # cosine_similarity results in tensor of range -1 - 1, but tf.sparse.eye has range 0 - 1
print("LABELS", labels) # print("LABELS", labels)
print("LOSS_rainfall", loss_rainfall) # print("LOSS_rainfall", loss_rainfall)
print("LOSS_water", loss_water) # print("LOSS_water", loss_water)
print("LOSS", loss) # print("LOSS", loss)
return loss return loss

View file

@ -2,7 +2,9 @@ import os
import tensorflow as tf import tensorflow as tf
def make_callbacks(dirpath): from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
def make_callbacks(dirpath, model_predict):
dirpath_checkpoints = os.path.join(dirpath, "checkpoints") dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
filepath_metrics = os.path.join(dirpath, "metrics.tsv") filepath_metrics = os.path.join(dirpath, "metrics.tsv")
@ -10,7 +12,8 @@ def make_callbacks(dirpath):
os.mkdir(dirpath_checkpoints) os.mkdir(dirpath_checkpoints)
return [ return [
tf.keras.callbacks.ModelCheckpoint( CallbackCustomModelCheckpoint(
model_to_checkpoint=model_predict,
filepath=os.path.join( filepath=os.path.join(
dirpath_checkpoints, dirpath_checkpoints,
"checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5" "checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5"

View file

@ -63,4 +63,9 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature) loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature)
) )
return model model_predict = tf.keras.Model(
inputs = input_rainfall,
outputs = rainfall
)
return model, model_predict