mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
ai: implement saving only the rainfall encoder
This commit is contained in:
parent
4c4358c3e5
commit
22620a1854
5 changed files with 49 additions and 18 deletions
|
@ -36,12 +36,12 @@ class RainfallWaterContraster(object):
|
|||
self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
|
||||
|
||||
writefile(self.filepath_summary, "") # Empty the file ahead of time
|
||||
self.model = self.make_model()
|
||||
self.make_model()
|
||||
|
||||
summarywriter(self.model, self.filepath_summary, append=True)
|
||||
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
|
||||
else:
|
||||
self.model = self.load_model(filepath_checkpoint)
|
||||
self.load_model(filepath_checkpoint)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
|
@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
|
|||
|
||||
|
||||
def make_model(self):
|
||||
return model_rainfallwater_contrastive(
|
||||
self.model, self.model_predict = model_rainfallwater_contrastive(
|
||||
batch_size=self.batch_size,
|
||||
summary_file=self.filepath_summary,
|
||||
**self.kwargs
|
||||
|
@ -72,7 +72,7 @@ class RainfallWaterContraster(object):
|
|||
filepath_checkpoint (string): The filepath to load the saved model from.
|
||||
"""
|
||||
|
||||
return tf.keras.models.load_model(filepath_checkpoint, custom_objects={
|
||||
self.model_predict = tf.keras.models.load_model(filepath_checkpoint, custom_objects={
|
||||
"LayerContrastiveEncoder": LayerContrastiveEncoder,
|
||||
"LayerCheeseMultipleOut": LayerCheeseMultipleOut
|
||||
})
|
||||
|
@ -84,7 +84,8 @@ class RainfallWaterContraster(object):
|
|||
dataset_train,
|
||||
validation_data=dataset_validate,
|
||||
epochs=self.epochs,
|
||||
callbacks=make_callbacks(self.dir_output)
|
||||
callbacks=make_callbacks(self.dir_output, self.model_predict),
|
||||
steps_per_epoch=10 # For testing
|
||||
)
|
||||
|
||||
def embed(self, dataset):
|
||||
|
@ -92,7 +93,7 @@ class RainfallWaterContraster(object):
|
|||
i_batch = -1
|
||||
for batch in dataset:
|
||||
i_batch += 1
|
||||
result_batch = self.model(batch[0])
|
||||
result_batch = self.model(batch[0]) # ((rainfall, water), dummy_label)
|
||||
rainfall, water = tf.unstack(result_batch, axis=-2)
|
||||
|
||||
rainfall = tf.unstack(rainfall, axis=0)
|
||||
|
@ -100,4 +101,11 @@ class RainfallWaterContraster(object):
|
|||
|
||||
result.extend(zip(rainfall, water))
|
||||
|
||||
return result
|
||||
|
||||
def embed_rainfall(self, dataset):
|
||||
result = []
|
||||
for batch in dataset:
|
||||
result_batch = self.model_predict(batch)
|
||||
result.extend(tf.unstack(result_batch, axis=0))
|
||||
return result
|
|
@ -0,0 +1,15 @@
|
|||
from loguru import logger
|
||||
import tensorflow as tf
|
||||
|
||||
class CallbackCustomModelCheckpoint(tf.keras.callbacks.Callback):
|
||||
def __init__(self, model_to_checkpoint, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_to_checkpoint = model_to_checkpoint
|
||||
self.checkpointer = tf.keras.callbacks.ModelCheckpoint(**kwargs)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
logger.info("Saving checkpoint")
|
||||
self.checkpointer.set_model(self.model_to_checkpoint)
|
||||
self.checkpointer.on_epoch_end(epoch=epoch, logs=logs)
|
||||
logger.info("Checkpoint saved successfully")
|
|
@ -8,10 +8,10 @@ class LossContrastive(tf.keras.losses.Loss):
|
|||
|
||||
def call(self, y_true, y_pred):
|
||||
rainfall, water = tf.unstack(y_pred, axis=-2)
|
||||
print("LOSS:call y_true", y_true.shape)
|
||||
print("LOSS:call y_pred", y_pred.shape)
|
||||
print("BEFORE_RESHAPE rainfall", rainfall)
|
||||
print("BEFORE_RESHAPE water", water)
|
||||
# print("LOSS:call y_true", y_true.shape)
|
||||
# print("LOSS:call y_pred", y_pred.shape)
|
||||
# print("BEFORE_RESHAPE rainfall", rainfall)
|
||||
# print("BEFORE_RESHAPE water", water)
|
||||
|
||||
# # Ensure the shapes are defined
|
||||
# rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]])
|
||||
|
@ -20,7 +20,7 @@ class LossContrastive(tf.keras.losses.Loss):
|
|||
|
||||
logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
|
||||
|
||||
print("LOGITS", logits)
|
||||
# print("LOGITS", logits)
|
||||
|
||||
labels = tf.eye(self.batch_size, dtype=tf.int32)
|
||||
loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
|
||||
|
@ -29,9 +29,9 @@ class LossContrastive(tf.keras.losses.Loss):
|
|||
loss = (loss_rainfall + loss_water) / 2
|
||||
|
||||
# cosine_similarity results in tensor of range -1 - 1, but tf.sparse.eye has range 0 - 1
|
||||
print("LABELS", labels)
|
||||
print("LOSS_rainfall", loss_rainfall)
|
||||
print("LOSS_water", loss_water)
|
||||
print("LOSS", loss)
|
||||
# print("LABELS", labels)
|
||||
# print("LOSS_rainfall", loss_rainfall)
|
||||
# print("LOSS_water", loss_water)
|
||||
# print("LOSS", loss)
|
||||
return loss
|
||||
|
|
@ -2,7 +2,9 @@ import os
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
def make_callbacks(dirpath):
|
||||
from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||
|
||||
def make_callbacks(dirpath, model_predict):
|
||||
dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
|
||||
filepath_metrics = os.path.join(dirpath, "metrics.tsv")
|
||||
|
||||
|
@ -10,7 +12,8 @@ def make_callbacks(dirpath):
|
|||
os.mkdir(dirpath_checkpoints)
|
||||
|
||||
return [
|
||||
tf.keras.callbacks.ModelCheckpoint(
|
||||
CallbackCustomModelCheckpoint(
|
||||
model_to_checkpoint=model_predict,
|
||||
filepath=os.path.join(
|
||||
dirpath_checkpoints,
|
||||
"checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5"
|
||||
|
|
|
@ -63,4 +63,9 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
|
|||
loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature)
|
||||
)
|
||||
|
||||
return model
|
||||
model_predict = tf.keras.Model(
|
||||
inputs = input_rainfall,
|
||||
outputs = rainfall
|
||||
)
|
||||
|
||||
return model, model_predict
|
Loading…
Reference in a new issue