From f1d7973f22268b26f897e5b66db6caae7b3aa8ff Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 1 Sep 2022 17:01:00 +0100 Subject: [PATCH] ai: add dummy label --- aimodel/src/lib/ai/components/LayerContrastiveEncoder.py | 4 ++-- aimodel/src/lib/dataset/dataset.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py index 2a2d94c..b5c2a3b 100644 --- a/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py +++ b/aimodel/src/lib/ai/components/LayerContrastiveEncoder.py @@ -52,8 +52,8 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer): # super().build(input_shape=input_shape[0]) # self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ])) - def call(self, input_thing): - result = self.encoder(input_thing) + def call(self, input_thing, training=False): + result = self.encoder(input_thing, training=training) # The encoder is handled by the ConvNeXt model \o/ # shape_ksize = result.shape[1] diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py index 74b22a6..e2b9a44 100644 --- a/aimodel/src/lib/dataset/dataset.py +++ b/aimodel/src/lib/dataset/dataset.py @@ -1,7 +1,6 @@ import os import math import json -from socket import if_nameindex from loguru import logger @@ -28,7 +27,7 @@ def parse_item(item): # TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here # TODO: Any other additional parsing here, since multiple .map() calls are not optimal - return rainfall, water + return ((rainfall, water), tf.ones(1)) def make_dataset(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64): return tf.data.TFRecordDataset(filenames,