mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
ai: implement saving only the rainfall encoder
This commit is contained in:
parent
4c4358c3e5
commit
22620a1854
5 changed files with 49 additions and 18 deletions
|
@ -36,12 +36,12 @@ class RainfallWaterContraster(object):
|
||||||
self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
|
self.filepath_summary = os.path.join(self.dir_output, "summary.txt")
|
||||||
|
|
||||||
writefile(self.filepath_summary, "") # Empty the file ahead of time
|
writefile(self.filepath_summary, "") # Empty the file ahead of time
|
||||||
self.model = self.make_model()
|
self.make_model()
|
||||||
|
|
||||||
summarywriter(self.model, self.filepath_summary, append=True)
|
summarywriter(self.model, self.filepath_summary, append=True)
|
||||||
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
|
writefile(os.path.join(self.dir_output, "params.json"), json.dumps(self.get_config()))
|
||||||
else:
|
else:
|
||||||
self.model = self.load_model(filepath_checkpoint)
|
self.load_model(filepath_checkpoint)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return {
|
return {
|
||||||
|
@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
|
||||||
|
|
||||||
|
|
||||||
def make_model(self):
|
def make_model(self):
|
||||||
return model_rainfallwater_contrastive(
|
self.model, self.model_predict = model_rainfallwater_contrastive(
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
summary_file=self.filepath_summary,
|
summary_file=self.filepath_summary,
|
||||||
**self.kwargs
|
**self.kwargs
|
||||||
|
@ -72,7 +72,7 @@ class RainfallWaterContraster(object):
|
||||||
filepath_checkpoint (string): The filepath to load the saved model from.
|
filepath_checkpoint (string): The filepath to load the saved model from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return tf.keras.models.load_model(filepath_checkpoint, custom_objects={
|
self.model_predict = tf.keras.models.load_model(filepath_checkpoint, custom_objects={
|
||||||
"LayerContrastiveEncoder": LayerContrastiveEncoder,
|
"LayerContrastiveEncoder": LayerContrastiveEncoder,
|
||||||
"LayerCheeseMultipleOut": LayerCheeseMultipleOut
|
"LayerCheeseMultipleOut": LayerCheeseMultipleOut
|
||||||
})
|
})
|
||||||
|
@ -84,7 +84,8 @@ class RainfallWaterContraster(object):
|
||||||
dataset_train,
|
dataset_train,
|
||||||
validation_data=dataset_validate,
|
validation_data=dataset_validate,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
callbacks=make_callbacks(self.dir_output)
|
callbacks=make_callbacks(self.dir_output, self.model_predict),
|
||||||
|
steps_per_epoch=10 # For testing
|
||||||
)
|
)
|
||||||
|
|
||||||
def embed(self, dataset):
|
def embed(self, dataset):
|
||||||
|
@ -92,7 +93,7 @@ class RainfallWaterContraster(object):
|
||||||
i_batch = -1
|
i_batch = -1
|
||||||
for batch in dataset:
|
for batch in dataset:
|
||||||
i_batch += 1
|
i_batch += 1
|
||||||
result_batch = self.model(batch[0])
|
result_batch = self.model(batch[0]) # ((rainfall, water), dummy_label)
|
||||||
rainfall, water = tf.unstack(result_batch, axis=-2)
|
rainfall, water = tf.unstack(result_batch, axis=-2)
|
||||||
|
|
||||||
rainfall = tf.unstack(rainfall, axis=0)
|
rainfall = tf.unstack(rainfall, axis=0)
|
||||||
|
@ -100,4 +101,11 @@ class RainfallWaterContraster(object):
|
||||||
|
|
||||||
result.extend(zip(rainfall, water))
|
result.extend(zip(rainfall, water))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def embed_rainfall(self, dataset):
|
||||||
|
result = []
|
||||||
|
for batch in dataset:
|
||||||
|
result_batch = self.model_predict(batch)
|
||||||
|
result.extend(tf.unstack(result_batch, axis=0))
|
||||||
return result
|
return result
|
|
@ -0,0 +1,15 @@
|
||||||
|
from loguru import logger
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
class CallbackCustomModelCheckpoint(tf.keras.callbacks.Callback):
|
||||||
|
def __init__(self, model_to_checkpoint, **kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model_to_checkpoint = model_to_checkpoint
|
||||||
|
self.checkpointer = tf.keras.callbacks.ModelCheckpoint(**kwargs)
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
logger.info("Saving checkpoint")
|
||||||
|
self.checkpointer.set_model(self.model_to_checkpoint)
|
||||||
|
self.checkpointer.on_epoch_end(epoch=epoch, logs=logs)
|
||||||
|
logger.info("Checkpoint saved successfully")
|
|
@ -8,10 +8,10 @@ class LossContrastive(tf.keras.losses.Loss):
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
def call(self, y_true, y_pred):
|
||||||
rainfall, water = tf.unstack(y_pred, axis=-2)
|
rainfall, water = tf.unstack(y_pred, axis=-2)
|
||||||
print("LOSS:call y_true", y_true.shape)
|
# print("LOSS:call y_true", y_true.shape)
|
||||||
print("LOSS:call y_pred", y_pred.shape)
|
# print("LOSS:call y_pred", y_pred.shape)
|
||||||
print("BEFORE_RESHAPE rainfall", rainfall)
|
# print("BEFORE_RESHAPE rainfall", rainfall)
|
||||||
print("BEFORE_RESHAPE water", water)
|
# print("BEFORE_RESHAPE water", water)
|
||||||
|
|
||||||
# # Ensure the shapes are defined
|
# # Ensure the shapes are defined
|
||||||
# rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]])
|
# rainfall = tf.reshape(rainfall, [self.batch_size, rainfall.shape[1]])
|
||||||
|
@ -20,7 +20,7 @@ class LossContrastive(tf.keras.losses.Loss):
|
||||||
|
|
||||||
logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
|
logits = tf.linalg.matmul(rainfall, tf.transpose(water)) * tf.clip_by_value(tf.math.exp(self.weight_temperature), 0, 100)
|
||||||
|
|
||||||
print("LOGITS", logits)
|
# print("LOGITS", logits)
|
||||||
|
|
||||||
labels = tf.eye(self.batch_size, dtype=tf.int32)
|
labels = tf.eye(self.batch_size, dtype=tf.int32)
|
||||||
loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
|
loss_rainfall = tf.keras.metrics.binary_crossentropy(labels, logits, from_logits=True, axis=0)
|
||||||
|
@ -29,9 +29,9 @@ class LossContrastive(tf.keras.losses.Loss):
|
||||||
loss = (loss_rainfall + loss_water) / 2
|
loss = (loss_rainfall + loss_water) / 2
|
||||||
|
|
||||||
# cosine_similarity results in tensor of range -1 - 1, but tf.sparse.eye has range 0 - 1
|
# cosine_similarity results in tensor of range -1 - 1, but tf.sparse.eye has range 0 - 1
|
||||||
print("LABELS", labels)
|
# print("LABELS", labels)
|
||||||
print("LOSS_rainfall", loss_rainfall)
|
# print("LOSS_rainfall", loss_rainfall)
|
||||||
print("LOSS_water", loss_water)
|
# print("LOSS_water", loss_water)
|
||||||
print("LOSS", loss)
|
# print("LOSS", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -2,7 +2,9 @@ import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def make_callbacks(dirpath):
|
from ..components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||||
|
|
||||||
|
def make_callbacks(dirpath, model_predict):
|
||||||
dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
|
dirpath_checkpoints = os.path.join(dirpath, "checkpoints")
|
||||||
filepath_metrics = os.path.join(dirpath, "metrics.tsv")
|
filepath_metrics = os.path.join(dirpath, "metrics.tsv")
|
||||||
|
|
||||||
|
@ -10,7 +12,8 @@ def make_callbacks(dirpath):
|
||||||
os.mkdir(dirpath_checkpoints)
|
os.mkdir(dirpath_checkpoints)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
tf.keras.callbacks.ModelCheckpoint(
|
CallbackCustomModelCheckpoint(
|
||||||
|
model_to_checkpoint=model_predict,
|
||||||
filepath=os.path.join(
|
filepath=os.path.join(
|
||||||
dirpath_checkpoints,
|
dirpath_checkpoints,
|
||||||
"checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5"
|
"checkpoint_weights_e{epoch:d}_loss{loss:.3f}.hdf5"
|
||||||
|
|
|
@ -63,4 +63,9 @@ def model_rainfallwater_contrastive(metadata, shape_water, batch_size=64, featur
|
||||||
loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature)
|
loss=LossContrastive(batch_size=batch_size, weight_temperature=weight_temperature)
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
model_predict = tf.keras.Model(
|
||||||
|
inputs = input_rainfall,
|
||||||
|
outputs = rainfall
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, model_predict
|
Loading…
Reference in a new issue