From 7685ec3e8bc1242100792ff75b30bc1c6a08f69c Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Tue, 13 Sep 2022 19:18:59 +0100
Subject: [PATCH] implement ability to embed & plot pretrained embeddings

---
 aimodel/src/lib/ai/RainfallWaterContraster.py |  6 +-
 aimodel/src/lib/dataset/dataset.py            | 26 +++++---
 aimodel/src/lib/io/handle_open.py             |  9 +++
 aimodel/src/subcommands/pretrain_plot.py      | 59 ++++++++++++++++++++
 aimodel/src/subcommands/pretrain_predict.py   | 18 +++++----
 5 files changed, 98 insertions(+), 20 deletions(-)
 create mode 100644 aimodel/src/lib/io/handle_open.py
 create mode 100644 aimodel/src/subcommands/pretrain_plot.py

diff --git a/aimodel/src/lib/ai/RainfallWaterContraster.py b/aimodel/src/lib/ai/RainfallWaterContraster.py
index 5140f4d..e7953c6 100644
--- a/aimodel/src/lib/ai/RainfallWaterContraster.py
+++ b/aimodel/src/lib/ai/RainfallWaterContraster.py
@@ -89,7 +89,6 @@ class RainfallWaterContraster(object):
 		)
 	
 	def embed(self, dataset):
-		result = []
 		i_batch = -1
 		for batch in dataset:
 			i_batch += 1
@@ -98,10 +97,9 @@ class RainfallWaterContraster(object):
 			rainfall = tf.unstack(rainfall, axis=0)
 			water = tf.unstack(water, axis=0)
-			
-			result.extend(zip(rainfall, water))
+			for step in zip(rainfall, water):
+				yield step
 		
-		return result
 	
 	def embed_rainfall(self, dataset):
 		result = []
 
diff --git a/aimodel/src/lib/dataset/dataset.py b/aimodel/src/lib/dataset/dataset.py
index e084207..403c158 100644
--- a/aimodel/src/lib/dataset/dataset.py
+++ b/aimodel/src/lib/dataset/dataset.py
@@ -43,10 +43,10 @@ def parse_item(metadata, shape_water_desired):
 	
 	return tf.function(parse_item_inner)
 
-def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
+def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
 	if "NO_PREFETCH" in os.environ:
 		logger.info("disabling data prefetching.")
-	return tf.data.TFRecordDataset(filenames,
+	return tf.data.TFRecordDataset(filepaths,
 		compression_type=compression_type,
 		num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
 	).shuffle(shuffle_buffer_size) \
@@ -55,11 +55,14 @@ def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression
 		.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
 
 
-def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
-	filepaths = shuffle(list(filter(
+def get_filepaths(dirpath_input):
+	return shuffle(list(filter(
 		lambda filepath: str(filepath).endswith(".tfrecord.gz"),
 		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
 	)))
+
+def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
+	filepaths = get_filepaths(dirpath_input)
 	filepaths_count = len(filepaths)
 	dataset_splitpoint = math.floor(filepaths_count * train_percentage)
 
@@ -73,9 +76,18 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
 	
 	return dataset_train, dataset_validate #, filepaths
 
-def dataset_predict():
-	raise NotImplementedError("Not implemented yet")
-
+def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
+	filepaths = get_filepaths(dirpath_input)
+	filepaths_count = len(filepaths)
+	for i in range(len(filepaths)): # pad the filepath list to twice its length by repeating the last entry
+		filepaths.append(filepaths[-1])
+	
+	return make_dataset(
+		filepaths=filepaths,
+		metadata=read_metadata(dirpath_input),
+		batch_size=batch_size,
+		parallel_reads_multiplier=parallel_reads_multiplier
+	), filepaths[0:filepaths_count], filepaths_count
 
 if __name__ == "__main__":
	ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
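Note on the RainfallWaterContraster change above: embed() is now a generator, yielding one (rainfall, water) tensor pair at a time instead of buffering every embedding into a list, so callers iterate over it rather than indexing or slicing the result. A minimal consumption sketch, assuming `ai` is a RainfallWaterContraster instance and `dataset` a tf.data pipeline (mirroring pretrain_predict.py later in this patch):

    # embed() streams (rainfall, water) pairs lazily; nothing is
    # materialised until the loop pulls the next batch through.
    for rainfall, water in ai.embed(dataset):
        print(rainfall.shape, water.shape)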
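Likewise for dataset.py above: dataset_predict() doubles the shuffled filepath list by repeating the last entry (presumably so the pipeline never runs short of a full batch), and returns the original filepaths and their count alongside the tf.data pipeline so a caller can tell real outputs from padding. A sketch of the returned contract (the directory path is illustrative):

    from lib.dataset.dataset import dataset_predict

    dataset, filepaths, filepaths_count = dataset_predict("/path/to/records")
    # `filepaths` is the unpadded list; only the first `filepaths_count`
    # inputs are real, anything fed in beyond that is padding.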
diff --git a/aimodel/src/lib/io/handle_open.py b/aimodel/src/lib/io/handle_open.py
new file mode 100644
index 0000000..a167ea0
--- /dev/null
+++ b/aimodel/src/lib/io/handle_open.py
@@ -0,0 +1,9 @@
+import io
+import gzip
+
+
+def handle_open(filepath, mode):
+	if filepath.endswith(".gz"):
+		return gzip.open(filepath, mode)
+	else:
+		return io.open(filepath, mode)
\ No newline at end of file
diff --git a/aimodel/src/subcommands/pretrain_plot.py b/aimodel/src/subcommands/pretrain_plot.py
new file mode 100644
index 0000000..33a355c
--- /dev/null
+++ b/aimodel/src/subcommands/pretrain_plot.py
@@ -0,0 +1,59 @@
+import json
+import os
+import sys
+import argparse
+
+from loguru import logger
+import tensorflow as tf
+import numpy as np
+
+from lib.io.handle_open import handle_open
+from lib.vis.embeddings import vis_embeddings
+
+def parse_args():
+	parser = argparse.ArgumentParser(description="Plot embeddings predicted by the contrastive learning pretrained model with UMAP.")
+	# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
+	parser.add_argument("--input", "-i", help="Path to the input file containing the embeddings to plot.", required=True)
+	parser.add_argument("--output", "-o", help="Path to the output file to write the resulting image to.", required=True)
+	parser.add_argument("--only-gpu",
+		help="If the GPU is not available, exit with an error (useful on shared HPC systems to avoid running out of memory & affecting other users)", action="store_true")
+	
+	return parser

+def run(args):
+	
+	# Note: the input file may be gzip-compressed (.gz); handle_open() handles this transparently.
+	
+	
+	if not os.path.exists(args.input):
+		raise Exception(f"Error: The specified input filepath ('{args.input}') does not exist.")
+	
+	filepath_input = args.input
+	
+	stem, ext = os.path.splitext(args.output)
+	filepath_output_rainfall = f"{stem}-rainfall{ext}" # ext already includes the leading dot
+	filepath_output_water = f"{stem}-water{ext}"
+	
+	
+	sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
+	
+	embeddings = []
+	with handle_open(filepath_input, "r") as handle: # "r", not "w": we are reading the embeddings back in
+		for line in handle:
+			obj = json.loads(line)
+			embeddings.append(obj["rainfall"])
+	
+	logger.info(">>> Plotting rainfall with UMAP\n")
+	vis_embeddings(filepath_output_rainfall, np.array(embeddings))
+	
+	
+	embeddings = []
+	with handle_open(filepath_input, "r") as handle:
+		for line in handle:
+			obj = json.loads(line)
+			embeddings.append(obj["water"])
+	
+	logger.info(">>> Plotting water with UMAP\n")
+	vis_embeddings(filepath_output_water, np.array(embeddings))
+	
+	sys.stderr.write(">>> Complete\n")
\ No newline at end of file
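A caveat on the new handle_open() helper above: gzip.open() interprets a bare "r"/"w" mode as binary, whereas io.open() interprets it as text, so the two branches hand back bytes vs str for the same mode string. Passing an explicit "rt"/"wt" keeps both branches in text mode; a usage sketch (the filename is illustrative):

    from lib.io.handle_open import handle_open

    # "rt" forces text mode whether or not the file is gzip-compressed.
    with handle_open("embeddings.jsonl.gz", "rt") as handle:
        for line in handle:
            print(line.rstrip())

(pretrain_plot.py gets away with a bare "r" because json.loads() accepts both str and bytes.)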
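For reference, the file pretrain_plot.py reads is the line-delimited JSON that pretrain_predict.py (below) writes: one object per line, with "rainfall" and "water" each holding one embedding vector as a plain list of floats. Illustratively (values and vector lengths are made up):

    {"rainfall":[0.12,-0.53,0.88],"water":[0.41,0.07,-0.66]}
    {"rainfall":[0.93,0.22,-0.18],"water":[-0.35,0.59,0.04]}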
diff --git a/aimodel/src/subcommands/pretrain_predict.py b/aimodel/src/subcommands/pretrain_predict.py
index f6e1ce8..bcea76a 100644
--- a/aimodel/src/subcommands/pretrain_predict.py
+++ b/aimodel/src/subcommands/pretrain_predict.py
@@ -52,12 +52,11 @@ def run(args):
 	
 	sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
 	
-	dataset_train, filepaths, filepaths_length = dataset_predict(
+	dataset, filepaths, filepaths_count = dataset_predict(
 		dirpath_input=args.input,
 		batch_size=ai.batch_size,
 		parallel_reads_multiplier=args.read_multiplier
 	)
-	filepaths = filepaths[0:filepaths_length]
 	
 	# for items in dataset_train.repeat(10):
 	# 	print("ITEMS", len(items))
@@ -69,18 +68,19 @@ def run(args):
 	if filepath_output != "-":
 		handle = io.open(filepath_output, "w")
 	
-	embeddings = ai.embed(dataset_train)[0:filepaths_length] # Trim off the padding
-	result = list(zip(filepaths, embeddings))
-	for filepath, embedding in result:
+	embeddings = [] # keep the rainfall embeddings so we can plot them below
+	for rainfall, water in ai.embed(dataset):
+		embeddings.append(rainfall)
 		handle.write(json.dumps({
-			"filepath": filepath,
-			"embedding": embedding.numpy().tolist()
+			"rainfall": rainfall.numpy().tolist(),
+			"water": water.numpy().tolist()
 		}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
 	
 	if filepath_output != "-":
+		handle.close()
+		
 		sys.stderr.write(">>> Plotting with UMAP\n")
 		filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
-		labels = [ os.path.basename(os.path.dirname(filepath)) for filepath in filepaths ]
-		vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]), np.array(labels))
+		vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]))
 	
 	sys.stderr.write(">>> Complete\n")
\ No newline at end of file
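With both subcommands in place, the intended flow is pretrain_predict first (streams the embeddings out as JSON lines) and pretrain_plot second (renders the UMAP plots from that file; gzip-compressed input is also accepted via handle_open). Only pretrain_plot's flags are defined in this patch and the subcommand dispatcher is not shown, so the invocation below is a hypothetical sketch; the entrypoint script name and filenames are assumptions:

    # Hypothetical: assumes the repo's usual subcommand dispatch.
    python aimodel/src/index.py pretrain-plot --input embeddings.jsonl --output umap.png
    # With the splitext()-based naming in pretrain_plot.py, this writes
    # umap-rainfall.png and umap-water.png.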