mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
finish train_predict
This commit is contained in:
parent
488f78fca5
commit
200076596b
2 changed files with 62 additions and 11 deletions
28
aimodel/src/lib/vis/segmentation_plot.py
Normal file
28
aimodel/src/lib/vis/segmentation_plot.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import matplotlib.pylab as plt
|
||||
|
||||
def segmentation_plot(water_actual, water_predict, model_code, filepath_output):
|
||||
# water_actual = [ width, height ]
|
||||
# water_predict = [ width, height ]
|
||||
|
||||
water_actual = water_actual.numpy()
|
||||
water_predict = water_predict.numpy()
|
||||
|
||||
px = 1 / plt.rcParams['figure.dpi'] # matplotlib sizes are in inches :-( :-( :-(
|
||||
width = 768*2
|
||||
height = 768
|
||||
|
||||
|
||||
plt.rc("font", size=20)
|
||||
plt.rc("font", family="Ubuntu")
|
||||
figure, axes = plt.subplot_mosaic("AB", figsize=(width*px, height*px))
|
||||
|
||||
axes["A"].imshow(water_actual)
|
||||
axes["A"].set_title(f"Actual", fontsize=20)
|
||||
|
||||
|
||||
axes["B"].imshow(water_predict)
|
||||
axes["A"].set_title(f"Predicted", fontsize=20)
|
||||
|
||||
|
||||
plt.suptitle(f"Rainfall → Water depth prediction | {model_code}", fontsize=28, weight="bold")
|
||||
plt.savefig(filepath_output)
|
|
@ -7,15 +7,14 @@ import re
|
|||
|
||||
from loguru import logger
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from lib.io.writefile import writefile
|
||||
from lib.vis.segmentation_plot import segmentation_plot
|
||||
from lib.io.handle_open import handle_open
|
||||
from lib.ai.RainfallWaterContraster import RainfallWaterContraster
|
||||
from lib.dataset.dataset import dataset_predict
|
||||
from lib.io.find_paramsjson import find_paramsjson
|
||||
from lib.io.readfile import readfile
|
||||
from lib.vis.embeddings import vis_embeddings
|
||||
from lib.vis.segmentation_plot import segmentation_plot
|
||||
|
||||
|
||||
MODE_JSONL = 1
|
||||
|
@ -29,8 +28,8 @@ def parse_args():
|
|||
parser.add_argument("--records-per-file", help="Optional, only valid with the .jsonl.gz file extension. If specified, this limits the number of records written to each file. When using this option, you MUST have the string '+d' (without quotes) somewhere in your output filepath.", type=int)
|
||||
parser.add_argument("--checkpoint", "-c", help="Checkpoint file to load model weights from.", required=True)
|
||||
parser.add_argument("--params", "-p", help="Optional. The file containing the model hyperparameters (usually called 'params.json'). If not specified, it's location will be determined automatically.")
|
||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||
|
||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 0. When using this start with 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data. SETTING THIS WILL SCRAMBLE THE ORDER OF THE DATASET.")
|
||||
parser.add_argument("--model-code", help="A description of the model used to predict the data. Will be inserted in the title of png plots.")
|
||||
return parser
|
||||
|
||||
def run(args):
|
||||
|
@ -40,11 +39,13 @@ def run(args):
|
|||
if (not hasattr(args, "params")) or args.params == None:
|
||||
args.params = find_paramsjson(args.checkpoint)
|
||||
if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None:
|
||||
args.read_multiplier = 1.5
|
||||
args.read_multiplier = 0
|
||||
if (not hasattr(args, "records_per_file")) or args.records_per_file == None:
|
||||
args.records_per_file = 0 # 0 = unlimited
|
||||
if (not hasattr(args, "output")) or args.output == None:
|
||||
args.output = "-"
|
||||
if (not hasattr(args, "model_code")) or args.model_code == None:
|
||||
args.model_code = ""
|
||||
|
||||
if not os.path.exists(args.params):
|
||||
raise Exception(f"Error: The specified filepath params.json hyperparameters ('{args.params}) does not exist.")
|
||||
|
@ -74,17 +75,39 @@ def run(args):
|
|||
# exit(0)
|
||||
|
||||
|
||||
|
||||
output_mode = MODE_PNG if args.output.endswith(".png") else MODE_JSONL
|
||||
logger.info("Output mode is "+("PNG" if output_mode == MODE_PNG else "JSONL"))
|
||||
logger.info(f"Records per file: {args.records_per_file}")
|
||||
|
||||
|
||||
do_jsonl(args, ai, dataset, write_mode)
|
||||
if output_mode == MODE_JSONL:
|
||||
do_jsonl(args, ai, dataset)
|
||||
else:
|
||||
do_png(args, ai, dataset, args.model_code)
|
||||
|
||||
sys.stderr.write(">>> Complete\n")
|
||||
|
||||
def do_png(args, ai, dataset, model_code):
|
||||
i = 0
|
||||
for rainfall, water in dataset:
|
||||
water_predict = ai.embed(rainfall)
|
||||
|
||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
# [ width, height ]
|
||||
water = tf.squeeze(water)
|
||||
|
||||
segmentation_plot(
|
||||
water, water_predict,
|
||||
model_code,
|
||||
args.output.replace("+d", str(i))
|
||||
)
|
||||
|
||||
i += 1
|
||||
|
||||
if i % 100 == 0:
|
||||
sys.stderr.write(f"Processed {i} items")
|
||||
|
||||
def do_jsonl(args, ai, dataset):
|
||||
output_mode = MODE_PNG if args.output.endswith(".png") else MODE_JSONL
|
||||
write_mode = "wt" if args.output.endswith(".gz") else "w"
|
||||
|
||||
handle = sys.stdout
|
||||
|
@ -108,7 +131,7 @@ def do_jsonl(args, ai, dataset):
|
|||
i_file = 0
|
||||
handle.close()
|
||||
logger.info(f"PROGRESS:file {files_done}")
|
||||
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode, handle_mode=output_mode)
|
||||
handle = handle_open(args.output.replace("+d", str(files_done+1)), write_mode)
|
||||
|
||||
handle.write(json.dumps(step_rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
|
||||
|
||||
|
|
Loading…
Reference in a new issue