mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 18:33:01 +00:00
implement ability to embed & plot pretrained embeddings
This commit is contained in:
parent
7130c4fdf8
commit
7685ec3e8b
5 changed files with 96 additions and 20 deletions
|
@ -89,7 +89,6 @@ class RainfallWaterContraster(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
def embed(self, dataset):
|
def embed(self, dataset):
|
||||||
result = []
|
|
||||||
i_batch = -1
|
i_batch = -1
|
||||||
for batch in dataset:
|
for batch in dataset:
|
||||||
i_batch += 1
|
i_batch += 1
|
||||||
|
@ -98,10 +97,9 @@ class RainfallWaterContraster(object):
|
||||||
|
|
||||||
rainfall = tf.unstack(rainfall, axis=0)
|
rainfall = tf.unstack(rainfall, axis=0)
|
||||||
water = tf.unstack(water, axis=0)
|
water = tf.unstack(water, axis=0)
|
||||||
|
for step in zip(rainfall, water):
|
||||||
|
yield step
|
||||||
|
|
||||||
result.extend(zip(rainfall, water))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def embed_rainfall(self, dataset):
|
def embed_rainfall(self, dataset):
|
||||||
result = []
|
result = []
|
||||||
|
|
|
@ -43,10 +43,10 @@ def parse_item(metadata, shape_water_desired):
|
||||||
|
|
||||||
return tf.function(parse_item_inner)
|
return tf.function(parse_item_inner)
|
||||||
|
|
||||||
def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||||
if "NO_PREFETCH" in os.environ:
|
if "NO_PREFETCH" in os.environ:
|
||||||
logger.info("disabling data prefetching.")
|
logger.info("disabling data prefetching.")
|
||||||
return tf.data.TFRecordDataset(filenames,
|
return tf.data.TFRecordDataset(filepaths,
|
||||||
compression_type=compression_type,
|
compression_type=compression_type,
|
||||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||||
).shuffle(shuffle_buffer_size) \
|
).shuffle(shuffle_buffer_size) \
|
||||||
|
@ -55,11 +55,14 @@ def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression
|
||||||
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
|
||||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
def get_filepaths(dirpath_input):
|
||||||
filepaths = shuffle(list(filter(
|
return shuffle(list(filter(
|
||||||
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||||
)))
|
)))
|
||||||
|
|
||||||
|
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||||
|
filepaths = get_filepaths(dirpath_input)
|
||||||
filepaths_count = len(filepaths)
|
filepaths_count = len(filepaths)
|
||||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||||
|
|
||||||
|
@ -73,9 +76,18 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
|
||||||
|
|
||||||
return dataset_train, dataset_validate #, filepaths
|
return dataset_train, dataset_validate #, filepaths
|
||||||
|
|
||||||
def dataset_predict():
|
def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
|
||||||
raise NotImplementedError("Not implemented yet")
|
filepaths = get_filepaths(dirpath_input)
|
||||||
|
filepaths_count = len(filepaths)
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
filepaths.append(filepaths[-1])
|
||||||
|
|
||||||
|
return make_dataset(
|
||||||
|
filepaths=filepaths,
|
||||||
|
metadata=read_metadata(dirpath_input),
|
||||||
|
batch_size=batch_size,
|
||||||
|
parallel_reads_multiplier=parallel_reads_multiplier
|
||||||
|
), filepaths[0:filepaths_count], filepaths_count
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
|
ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
|
||||||
|
|
9
aimodel/src/lib/io/handle_open.py
Normal file
9
aimodel/src/lib/io/handle_open.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
import io
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
|
||||||
|
def handle_open(filepath, mode):
|
||||||
|
if filepath.endswith(".gz"):
|
||||||
|
return gzip.open(filepath, mode)
|
||||||
|
else:
|
||||||
|
return io.open(filepath, mode)
|
59
aimodel/src/subcommands/pretrain_plot.py
Normal file
59
aimodel/src/subcommands/pretrain_plot.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lib.io.handle_open import handle_open
|
||||||
|
from lib.vis.embeddings import vis_embeddings
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Plot embeddings predicted by the contrastive learning pretrained model with UMAP.")
|
||||||
|
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
||||||
|
parser.add_argument("--input", "-i", help="Path to input file containing the content to plot.", required=True)
|
||||||
|
parser.add_argument("--output", "-o", help="Path to output file to write the resulting image to.", required=True)
|
||||||
|
parser.add_argument("--only-gpu",
|
||||||
|
help="If the GPU is not available, exit with an error (useful on shared HPC systems to avoid running out of memory & affecting other users)", action="store_true")
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def run(args):
|
||||||
|
|
||||||
|
# Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
|
||||||
|
|
||||||
|
|
||||||
|
if not os.path.exists(args.input):
|
||||||
|
raise Exception(f"Error: The specified input filepath ('{args.input}) does not exist.")
|
||||||
|
|
||||||
|
filepath_input = args.input
|
||||||
|
|
||||||
|
stem, ext = os.path.splitext(args.output)
|
||||||
|
filepath_output_rainfall = f"{stem}-rainfall.{ext}"
|
||||||
|
filepath_output_water = f"{stem}-water.{ext}"
|
||||||
|
|
||||||
|
|
||||||
|
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
with handle_open(filepath_input, "w") as handle:
|
||||||
|
for line in handle:
|
||||||
|
obj = json.loads(line)
|
||||||
|
embeddings.append(obj["rainfall"])
|
||||||
|
|
||||||
|
logger.info(">>> Plotting rainfall with UMAP\n")
|
||||||
|
vis_embeddings(filepath_output_rainfall, np.array(embeddings))
|
||||||
|
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
with handle_open(filepath_input, "w") as handle:
|
||||||
|
for line in handle:
|
||||||
|
obj = json.loads(line)
|
||||||
|
embeddings.append(obj["water"])
|
||||||
|
|
||||||
|
logger.info(">>> Plotting water with UMAP\n")
|
||||||
|
vis_embeddings(filepath_output_water, np.array(embeddings))
|
||||||
|
|
||||||
|
sys.stderr.write(">>> Complete\n")
|
|
@ -52,12 +52,11 @@ def run(args):
|
||||||
|
|
||||||
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||||
|
|
||||||
dataset_train, filepaths, filepaths_length = dataset_predict(
|
dataset = dataset_predict(
|
||||||
dirpath_input=args.input,
|
dirpath_input=args.input,
|
||||||
batch_size=ai.batch_size,
|
batch_size=ai.batch_size,
|
||||||
parallel_reads_multiplier=args.read_multiplier
|
parallel_reads_multiplier=args.read_multiplier
|
||||||
)
|
)
|
||||||
filepaths = filepaths[0:filepaths_length]
|
|
||||||
|
|
||||||
# for items in dataset_train.repeat(10):
|
# for items in dataset_train.repeat(10):
|
||||||
# print("ITEMS", len(items))
|
# print("ITEMS", len(items))
|
||||||
|
@ -69,18 +68,17 @@ def run(args):
|
||||||
if filepath_output != "-":
|
if filepath_output != "-":
|
||||||
handle = io.open(filepath_output, "w")
|
handle = io.open(filepath_output, "w")
|
||||||
|
|
||||||
embeddings = ai.embed(dataset_train)[0:filepaths_length] # Trim off the padding
|
for rainfall, water in ai.embed(dataset):
|
||||||
result = list(zip(filepaths, embeddings))
|
|
||||||
for filepath, embedding in result:
|
|
||||||
handle.write(json.dumps({
|
handle.write(json.dumps({
|
||||||
"filepath": filepath,
|
"rainfall": rainfall.numpy().tolist(),
|
||||||
"embedding": embedding.numpy().tolist()
|
"water": water.numpy().tolist()
|
||||||
}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||||
|
|
||||||
if filepath_output != "-":
|
if filepath_output != "-":
|
||||||
|
handle.close()
|
||||||
|
|
||||||
sys.stderr.write(">>> Plotting with UMAP\n")
|
sys.stderr.write(">>> Plotting with UMAP\n")
|
||||||
filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
|
filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
|
||||||
labels = [ os.path.basename(os.path.dirname(filepath)) for filepath in filepaths ]
|
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]))
|
||||||
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]), np.array(labels))
|
|
||||||
|
|
||||||
sys.stderr.write(">>> Complete\n")
|
sys.stderr.write(">>> Complete\n")
|
Loading…
Reference in a new issue