implement ability to embed & plot pretrained embeddings

This commit is contained in:
Starbeamrainbowlabs 2022-09-13 19:18:59 +01:00
parent 7130c4fdf8
commit 7685ec3e8b
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
5 changed files with 96 additions and 20 deletions

View file

@ -89,7 +89,6 @@ class RainfallWaterContraster(object):
)
def embed(self, dataset):
result = []
i_batch = -1
for batch in dataset:
i_batch += 1
@ -98,10 +97,9 @@ class RainfallWaterContraster(object):
rainfall = tf.unstack(rainfall, axis=0)
water = tf.unstack(water, axis=0)
result.extend(zip(rainfall, water))
for step in zip(rainfall, water):
yield step
return result
def embed_rainfall(self, dataset):
result = []

View file

@ -43,10 +43,10 @@ def parse_item(metadata, shape_water_desired):
return tf.function(parse_item_inner)
def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
if "NO_PREFETCH" in os.environ:
logger.info("disabling data prefetching.")
return tf.data.TFRecordDataset(filenames,
return tf.data.TFRecordDataset(filepaths,
compression_type=compression_type,
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
).shuffle(shuffle_buffer_size) \
@ -55,11 +55,14 @@ def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
filepaths = shuffle(list(filter(
def get_filepaths(dirpath_input):
return shuffle(list(filter(
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
)))
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
filepaths = get_filepaths(dirpath_input)
filepaths_count = len(filepaths)
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
@ -73,9 +76,18 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
return dataset_train, dataset_validate #, filepaths
def dataset_predict():
raise NotImplementedError("Not implemented yet")
def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
filepaths = get_filepaths(dirpath_input)
filepaths_count = len(filepaths)
for i in range(len(filepaths)):
filepaths.append(filepaths[-1])
return make_dataset(
filepaths=filepaths,
metadata=read_metadata(dirpath_input),
batch_size=batch_size,
parallel_reads_multiplier=parallel_reads_multiplier
), filepaths[0:filepaths_count], filepaths_count
if __name__ == "__main__":
ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")

View file

@ -0,0 +1,9 @@
import io
import gzip
def handle_open(filepath, mode):
if filepath.endswith(".gz"):
return gzip.open(filepath, mode)
else:
return io.open(filepath, mode)

View file

@ -0,0 +1,59 @@
import json
import os
import sys
import argparse
from loguru import logger
import tensorflow as tf
import numpy as np
from lib.io.handle_open import handle_open
from lib.vis.embeddings import vis_embeddings
def parse_args():
parser = argparse.ArgumentParser(description="Plot embeddings predicted by the contrastive learning pretrained model with UMAP.")
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
parser.add_argument("--input", "-i", help="Path to input file containing the content to plot.", required=True)
parser.add_argument("--output", "-o", help="Path to output file to write the resulting image to.", required=True)
parser.add_argument("--only-gpu",
help="If the GPU is not available, exit with an error (useful on shared HPC systems to avoid running out of memory & affecting other users)", action="store_true")
return parser
def run(args):
# Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
if not os.path.exists(args.input):
raise Exception(f"Error: The specified input filepath ('{args.input}) does not exist.")
filepath_input = args.input
stem, ext = os.path.splitext(args.output)
filepath_output_rainfall = f"{stem}-rainfall.{ext}"
filepath_output_water = f"{stem}-water.{ext}"
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
embeddings = []
with handle_open(filepath_input, "w") as handle:
for line in handle:
obj = json.loads(line)
embeddings.append(obj["rainfall"])
logger.info(">>> Plotting rainfall with UMAP\n")
vis_embeddings(filepath_output_rainfall, np.array(embeddings))
embeddings = []
with handle_open(filepath_input, "w") as handle:
for line in handle:
obj = json.loads(line)
embeddings.append(obj["water"])
logger.info(">>> Plotting water with UMAP\n")
vis_embeddings(filepath_output_water, np.array(embeddings))
sys.stderr.write(">>> Complete\n")

View file

@ -52,12 +52,11 @@ def run(args):
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
dataset_train, filepaths, filepaths_length = dataset_predict(
dataset = dataset_predict(
dirpath_input=args.input,
batch_size=ai.batch_size,
parallel_reads_multiplier=args.read_multiplier
)
filepaths = filepaths[0:filepaths_length]
# for items in dataset_train.repeat(10):
# print("ITEMS", len(items))
@ -69,18 +68,17 @@ def run(args):
if filepath_output != "-":
handle = io.open(filepath_output, "w")
embeddings = ai.embed(dataset_train)[0:filepaths_length] # Trim off the padding
result = list(zip(filepaths, embeddings))
for filepath, embedding in result:
for rainfall, water in ai.embed(dataset):
handle.write(json.dumps({
"filepath": filepath,
"embedding": embedding.numpy().tolist()
"rainfall": rainfall.numpy().tolist(),
"water": water.numpy().tolist()
}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
if filepath_output != "-":
handle.close()
sys.stderr.write(">>> Plotting with UMAP\n")
filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
labels = [ os.path.basename(os.path.dirname(filepath)) for filepath in filepaths ]
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]), np.array(labels))
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]))
sys.stderr.write(">>> Complete\n")