mirror of https://github.com/sbrl/research-rainfallradar
synced 2024-07-04 20:24:55 +00:00
86 lines | 3.9 KiB | Python
|
import io
|
||
|
import json
|
||
|
import os
|
||
|
import sys
|
||
|
import argparse
|
||
|
import re
|
||
|
|
||
|
from loguru import logger
|
||
|
import tensorflow as tf
|
||
|
import numpy as np
|
||
|
|
||
|
from lib.ai.RainfallWaterContraster import RainfallWaterContraster
|
||
|
from lib.dataset.dataset import dataset_predict
|
||
|
from lib.io.find_paramsjson import find_paramsjson
|
||
|
from lib.io.readfile import readfile
|
||
|
from lib.vis.embeddings import vis_embeddings
|
||
|
|
||
|
def parse_args():
	"""Build the argument parser for the embedding-prediction CLI.
	
	Returns:
		argparse.ArgumentParser: Configured parser; call .parse_args() on the
			result to obtain the namespace consumed by run().
	"""
	parser = argparse.ArgumentParser(description="Output feature maps using a given pretrained contrastive model.")
	# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
	parser.add_argument("--input", "-i", help="Path to input directory containing the images to predict for.", required=True)
	# NOTE: help text corrected — the code produces a UMAP graph only WHEN an
	# output file is specified (unless --no-vis is passed), not the reverse.
	parser.add_argument("--output", "-o", help="Path to output file to write output to. Defaults to stdout. If specified, a UMAP graph will also be produced unless --no-vis is passed.")
	parser.add_argument("--checkpoint", "-c", help="Checkpoint file to load model weights from.", required=True)
	parser.add_argument("--params", "-p", help="Optional. The file containing the model hyperparameters (usually called 'params.json'). If not specified, it's location will be determined automatically.")
	# type=float + default so downstream code receives a number rather than a
	# raw string, and doesn't need its own fallback for the default value.
	parser.add_argument("--reads-multiplier", type=float, default=1.5, help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
	parser.add_argument("--no-vis",
		help="Don't also plot a visualisation of the resulting embeddings.", action="store_true")
	parser.add_argument("--only-gpu",
		help="If the GPU is not available, exit with an error (useful on shared HPC systems to avoid running out of memory & affecting other users)", action="store_true")
	
	return parser
def run(args):
	"""Embed images with a pretrained contrastive model and write the results.
	
	Writes one compact JSON object per input file ({"filepath": ..., "embedding":
	[...]}, newline-delimited) to stdout or to --output. When an output file is
	given and --no-vis is not set, also renders a UMAP visualisation of the
	embeddings to a sibling .png file.
	
	Args:
		args (argparse.Namespace): Parsed CLI arguments (see parse_args()).
	
	Raises:
		Exception: If the params.json filepath or the checkpoint filepath does
			not exist.
	"""
	# Note that we do NOT check to see if the checkpoint file exists, because
	# Tensorflow/Keras requires that we pass the stem instead of the actual index file..... :-/
	if (not hasattr(args, "params")) or args.params is None:
		args.params = find_paramsjson(args.checkpoint)
	# Fix: argparse stores --reads-multiplier as "reads_multiplier"; the old
	# code checked "read_multiplier" (no 's'), so the CLI value was silently
	# ignored and 1.5 was always used.
	if (not hasattr(args, "reads_multiplier")) or args.reads_multiplier is None:
		args.reads_multiplier = 1.5
	else:
		# May arrive as a str if the parser was built without type=float.
		args.reads_multiplier = float(args.reads_multiplier)
	
	if not os.path.exists(args.params):
		raise Exception(f"Error: The specified filepath params.json hyperparameters ('{args.params}') does not exist.")
	if not os.path.exists(args.checkpoint):
		raise Exception(f"Error: The specified filepath to the checkpoint to load ('{args.checkpoint}') does not exist.")
	
	filepath_output = args.output if hasattr(args, "output") and args.output is not None else "-"
	
	ai = RainfallWaterContraster.from_checkpoint(args.checkpoint)
	
	sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
	
	dataset_train, filepaths, filepaths_length = dataset_predict(
		dirpath_input=args.input,
		batch_size=ai.batch_size,
		parallel_reads_multiplier=args.reads_multiplier
	)
	# The dataset pads the filepath list to a whole number of batches; trim it.
	filepaths = filepaths[0:filepaths_length]
	
	handle = sys.stdout
	if filepath_output != "-":
		handle = io.open(filepath_output, "w")
	try:
		embeddings = ai.embed(dataset_train)[0:filepaths_length] # Trim off the padding
		for filepath, embedding in zip(filepaths, embeddings):
			handle.write(json.dumps({
				"filepath": filepath,
				"embedding": embedding.numpy().tolist()
			}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
	finally:
		# Fix: the output file handle was previously never closed.
		if handle is not sys.stdout:
			handle.close()
	
	# Fix: honour --no-vis, which was previously parsed but ignored.
	if filepath_output != "-" and not getattr(args, "no_vis", False):
		sys.stderr.write(">>> Plotting with UMAP\n")
		filepath_output_umap = os.path.splitext(filepath_output)[0]+'.png'
		# Label each embedding with the name of its parent directory.
		labels = [ os.path.basename(os.path.dirname(filepath)) for filepath in filepaths ]
		vis_embeddings(filepath_output_umap, np.array([ embedding.numpy() for embedding in embeddings ]), np.array(labels))
	
	sys.stderr.write(">>> Complete\n")