mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
ai: add dummy label
This commit is contained in:
parent
17d42fe899
commit
f1d7973f22
2 changed files with 3 additions and 4 deletions
|
@ -52,8 +52,8 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||||
# super().build(input_shape=input_shape[0])
|
# super().build(input_shape=input_shape[0])
|
||||||
# self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
# self.embedding.build(input_shape=tf.TensorShape([ *self.embedding_input_shape ]))
|
||||||
|
|
||||||
def call(self, input_thing):
|
def call(self, input_thing, training=False):
|
||||||
result = self.encoder(input_thing)
|
result = self.encoder(input_thing, training=training)
|
||||||
|
|
||||||
# The encoder is handled by the ConvNeXt model \o/
|
# The encoder is handled by the ConvNeXt model \o/
|
||||||
# shape_ksize = result.shape[1]
|
# shape_ksize = result.shape[1]
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import json
|
import json
|
||||||
from socket import if_nameindex
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -28,7 +27,7 @@ def parse_item(item):
|
||||||
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
# TODO: The shape of the resulting tensor can't be statically determined, so we need to reshape here
|
||||||
|
|
||||||
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
# TODO: Any other additional parsing here, since multiple .map() calls are not optimal
|
||||||
return rainfall, water
|
return ((rainfall, water), tf.ones(1))
|
||||||
|
|
||||||
def make_dataset(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
def make_dataset(filenames, compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||||
return tf.data.TFRecordDataset(filenames,
|
return tf.data.TFRecordDataset(filenames,
|
||||||
|
|
Loading…
Reference in a new issue