mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
implement ability to embed & plot pretrained embeddings
This commit is contained in:
parent
7130c4fdf8
commit
7685ec3e8b
5 changed files with 96 additions and 20 deletions
|
@ -89,7 +89,6 @@ class RainfallWaterContraster(object):
|
|||
)
|
||||
|
||||
def embed(self, dataset):
|
||||
result = []
|
||||
i_batch = -1
|
||||
for batch in dataset:
|
||||
i_batch += 1
|
||||
|
@ -98,10 +97,9 @@ class RainfallWaterContraster(object):
|
|||
|
||||
rainfall = tf.unstack(rainfall, axis=0)
|
||||
water = tf.unstack(water, axis=0)
|
||||
for step in zip(rainfall, water):
|
||||
yield step
|
||||
|
||||
result.extend(zip(rainfall, water))
|
||||
|
||||
return result
|
||||
|
||||
def embed_rainfall(self, dataset):
|
||||
result = []
|
||||
|
|
|
@ -43,10 +43,10 @@ def parse_item(metadata, shape_water_desired):
|
|||
|
||||
return tf.function(parse_item_inner)
|
||||
|
||||
def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||
def make_dataset(filepaths, metadata, shape_watch_desired=[100,100], compression_type="GZIP", parallel_reads_multiplier=1.5, shuffle_buffer_size=128, batch_size=64):
|
||||
if "NO_PREFETCH" in os.environ:
|
||||
logger.info("disabling data prefetching.")
|
||||
return tf.data.TFRecordDataset(filenames,
|
||||
return tf.data.TFRecordDataset(filepaths,
|
||||
compression_type=compression_type,
|
||||
num_parallel_reads=math.ceil(os.cpu_count() * parallel_reads_multiplier)
|
||||
).shuffle(shuffle_buffer_size) \
|
||||
|
@ -55,11 +55,14 @@ def make_dataset(filenames, metadata, shape_watch_desired=[100,100], compression
|
|||
.prefetch(0 if "NO_PREFETCH" in os.environ else tf.data.AUTOTUNE)
|
||||
|
||||
|
||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||
filepaths = shuffle(list(filter(
|
||||
def get_filepaths(dirpath_input):
|
||||
return shuffle(list(filter(
|
||||
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||
)))
|
||||
|
||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||
filepaths = get_filepaths(dirpath_input)
|
||||
filepaths_count = len(filepaths)
|
||||
dataset_splitpoint = math.floor(filepaths_count * train_percentage)
|
||||
|
||||
|
@ -73,9 +76,18 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
|
|||
|
||||
return dataset_train, dataset_validate #, filepaths
|
||||
|
||||
def dataset_predict():
|
||||
raise NotImplementedError("Not implemented yet")
|
||||
def dataset_predict(dirpath_input, batch_size=64, parallel_reads_multiplier=1.5):
|
||||
filepaths = get_filepaths(dirpath_input)
|
||||
filepaths_count = len(filepaths)
|
||||
for i in range(len(filepaths)):
|
||||
filepaths.append(filepaths[-1])
|
||||
|
||||
return make_dataset(
|
||||
filepaths=filepaths,
|
||||
metadata=read_metadata(dirpath_input),
|
||||
batch_size=batch_size,
|
||||
parallel_reads_multiplier=parallel_reads_multiplier
|
||||
), filepaths[0:filepaths_count], filepaths_count
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds_train, ds_validate = dataset("/mnt/research-data/main/rainfallwater_records-viperfinal/")
|
||||
|
|
9
aimodel/src/lib/io/handle_open.py
Normal file
9
aimodel/src/lib/io/handle_open.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
import io
|
||||
import gzip
|
||||
|
||||
|
||||
def handle_open(filepath, mode):
|
||||
if filepath.endswith(".gz"):
|
||||
return gzip.open(filepath, mode)
|
||||
else:
|
||||
return io.open(filepath, mode)
|
59
aimodel/src/subcommands/pretrain_plot.py
Normal file
59
aimodel/src/subcommands/pretrain_plot.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from loguru import logger
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from lib.io.handle_open import handle_open
|
||||
from lib.vis.embeddings import vis_embeddings
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Plot embeddings predicted by the contrastive learning pretrained model with UMAP.")
|
||||
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
||||
parser.add_argument("--input", "-i", help="Path to input file containing the content to plot.", required=True)
|
||||
parser.add_argument("--output", "-o", help="Path to output file to write the resulting image to.", required=True)
|
||||
parser.add_argument("--only-gpu",
|
||||
help="If the GPU is not available, exit with an error (useful on shared HPC systems to avoid running out of memory & affecting other users)", action="store_true")
|
||||
|
||||
return parser
|
||||
|
||||
def run(args):
|
||||
|
||||
# Note that we do NOT check to see if the checkpoint file exists, because Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
|
||||
|
||||
|
||||
if not os.path.exists(args.input):
|
||||
raise Exception(f"Error: The specified input filepath ('{args.input}) does not exist.")
|
||||
|
||||
filepath_input = args.input
|
||||
|
||||
stem, ext = os.path.splitext(args.output)
|
||||
filepath_output_rainfall = f"{stem}-rainfall.{ext}"
|
||||
filepath_output_water = f"{stem}-water.{ext}"
|
||||
|
||||
|
||||
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||
|
||||
embeddings = []
|
||||
with handle_open(filepath_input, "w") as handle:
|
||||
for line in handle:
|
||||
obj = json.loads(line)
|
||||
embeddings.append(obj["rainfall"])
|
||||
|
||||
logger.info(">>> Plotting rainfall with UMAP\n")
|
||||
vis_embeddings(filepath_output_rainfall, np.array(embeddings))
|
||||
|
||||
|
||||
embeddings = []
|
||||
with handle_open(filepath_input, "w") as handle:
|
||||
for line in handle:
|
||||
obj = json.loads(line)
|
||||
embeddings.append(obj["water"])
|
||||
|
||||
logger.info(">>> Plotting water with UMAP\n")
|
||||
vis_embeddings(filepath_output_water, np.array(embeddings))
|
||||
|
||||
sys.stderr.write(">>> Complete\n")
|
|
@ -52,12 +52,11 @@ def run(args):
|
|||
|
||||
sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
|
||||
|
||||
dataset_train, filepaths, filepaths_length = dataset_predict(
|
||||
dataset = dataset_predict(
|
||||
dirpath_input=args.input,
|
||||
batch_size=ai.batch_size,
|
||||
parallel_reads_multiplier=args.read_multiplier
|
||||
)
|
||||
filepaths = filepaths[0:filepaths_length]
|
||||
|
||||
# for items in dataset_train.repeat(10):
|
||||
# print("ITEMS", len(items))
|
||||
|
@ -69,18 +68,17 @@ def run(args):
|
|||
if filepath_output != "-":
|
||||
handle = io.open(filepath_output, "w")
|
||||
|
||||
embeddings = ai.embed(dataset_train)[0:filepaths_length] # Trim off the padding
|
||||
result = list(zip(filepaths, embeddings))
|
||||
for filepath, embedding in result:
|
||||
for rainfall, water in ai.embed(dataset):
|
||||
handle.write(json.dumps({
|
||||
"filepath": filepath,
|
||||
"embedding": embedding.numpy().tolist()
|
||||
"rainfall": rainfall.numpy().tolist(),
|
||||
"water": water.numpy().tolist()
|
||||
}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||
|
||||
if filepath_output != "-":
|
||||
handle.close()
|
||||
|
||||
sys.stderr.write(">>> Plotting with UMAP\n")
|
||||
filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
|
||||
labels = [ os.path.basename(os.path.dirname(filepath)) for filepath in filepaths ]
|
||||
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]), np.array(labels))
|
||||
vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]))
|
||||
|
||||
sys.stderr.write(">>> Complete\n")
|
Loading…
Reference in a new issue