2022-10-18 18:37:55 +00:00
import io
import json
import os
import sys
import argparse
import re
from loguru import logger
import tensorflow as tf
2022-10-20 14:22:29 +00:00
from lib . dataset . batched_iterator import batched_iterator
2022-10-18 18:37:55 +00:00
2022-10-19 16:26:40 +00:00
from lib . vis . segmentation_plot import segmentation_plot
2022-10-18 18:37:55 +00:00
from lib . io . handle_open import handle_open
2022-10-21 13:25:43 +00:00
from lib . ai . RainfallWaterSegmenter import RainfallWaterSegmenter
2022-10-20 14:16:24 +00:00
from lib . dataset . dataset_segmenter import dataset_predict
2022-10-18 18:37:55 +00:00
from lib . io . find_paramsjson import find_paramsjson
from lib . io . readfile import readfile
2022-10-19 16:26:40 +00:00
from lib . vis . segmentation_plot import segmentation_plot
2022-10-18 18:37:55 +00:00
# Output modes; run() picks one from the extension of --output.
MODE_JSONL = 1  # one JSON object per line, written by do_jsonl()
MODE_PNG = 2  # one segmentation plot per item, written by do_png()
def parse_args():
    """Build the command-line parser for this script.

    Returns the ``argparse.ArgumentParser`` itself (not parsed arguments);
    callers invoke ``.parse_args()`` on it and hand the result to ``run``.
    """
    parser = argparse.ArgumentParser(description="Output water depth image segmentation maps using a given pretrained image segmentation model.")
    parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord(.gz) files to predict for. If a single file is passed instead, then only that file will be converted.", required=True)
    parser.add_argument("--output", "-o", help="Path to output file to write output to. If the file extension .png is used instead of .jsonl.gz, then an image is written instead (+d is replaced with the item index).")
    parser.add_argument("--records-per-file", help="Optional, only valid with the .jsonl.gz file extension. If specified, this limits the number of records written to each file. When using this option, you MUST have the string '+d' (without quotes) somewhere in your output filepath.", type=int)
    parser.add_argument("--checkpoint", "-c", help="Checkpoint file to load model weights from.", required=True)
    parser.add_argument("--params", "-p", help="Optional. The file containing the model hyperparameters (usually called 'params.json'). If not specified, it's location will be determined automatically.")
    # run() reads ``args.read_multiplier``. Without an explicit dest, argparse
    # stored this flag as ``reads_multiplier`` and the value was silently
    # ignored. The help text describes fractional values (e.g. 1.5), so parse
    # it as a float rather than leaving it as a string.
    parser.add_argument("--reads-multiplier", dest="read_multiplier", type=float, help="Optional. The multiplier for the number of files we should read from at once. Defaults to 0. When using this start with 1.5, which means read ceil(NUMBER_OF_CORES * 1.5). Set to a higher number of systems with high read latency to avoid starving the GPU of data. SETTING THIS WILL SCRAMBLE THE ORDER OF THE DATASET.")
    parser.add_argument("--model-code", help="A description of the model used to predict the data. Will be inserted in the title of png plots.")
    parser.add_argument("--log", help="Optional. If specified when the file extension is .jsonl[.gz], then this chooses what is logged. Specify a comma separated list of values. Possible values: rainfall_actual, water_actual, water_predict. Default: rainfall_actual,water_actual,water_predict.")
    return parser
def run(args):
    """Predict water-depth segmentation maps for a dataset and write them out.

    ``args`` is normally the namespace built by ``parse_args().parse_args()``,
    but any object with the same attributes works. Every optional attribute
    that is missing or ``None`` is filled in with a default, and ``args`` is
    mutated in place: ``args.log`` becomes a list of field names and
    ``args.read_multiplier`` a number.

    Output is written by ``do_png`` when ``args.output`` ends in ``.png`` and
    by ``do_jsonl`` otherwise (``"-"`` meaning standard output).

    Raises:
        Exception: if the hyperparameters file or the checkpoint cannot be found.
    """
    if getattr(args, "params", None) is None:
        args.params = find_paramsjson(args.checkpoint)

    read_multiplier = getattr(args, "read_multiplier", None)
    if read_multiplier is None:
        # Older parse_args() stored --reads-multiplier under this name, so the
        # flag used to be silently ignored; honour it if present.
        read_multiplier = getattr(args, "reads_multiplier", None)
    if isinstance(read_multiplier, str):
        read_multiplier = float(read_multiplier)
    args.read_multiplier = 0 if read_multiplier is None else read_multiplier

    if getattr(args, "records_per_file", None) is None:
        args.records_per_file = 0  # 0 = unlimited
    if getattr(args, "output", None) is None:
        args.output = "-"  # "-" = standard output (see do_jsonl)
    if getattr(args, "model_code", None) is None:
        args.model_code = ""
    if getattr(args, "log", None) is None:
        args.log = "rainfall_actual,water_actual,water_predict"
    if isinstance(args.log, str):
        # Strip each entry so "a, b" selects "b" rather than " b".
        args.log = [name.strip() for name in args.log.split(",") if name.strip()]

    if not os.path.exists(args.params):
        raise Exception(f"Error: The specified filepath params.json hyperparameters ('{args.params}') does not exist.")
    # TensorFlow/Keras must be given the checkpoint *stem* rather than the
    # actual index file (TF writes "<stem>.index" on disk). Accept either form
    # instead of rejecting a valid stem.
    if not (os.path.exists(args.checkpoint) or os.path.exists(f"{args.checkpoint}.index")):
        raise Exception(f"Error: The specified filepath to the checkpoint to load ('{args.checkpoint}') does not exist.")

    if args.records_per_file > 0 and args.output.endswith(".jsonl.gz"):
        dirpath_output = os.path.dirname(args.output)
        # dirname() is "" for a bare filename, and os.mkdir("") would raise.
        if dirpath_output:
            os.makedirs(dirpath_output, exist_ok=True)

    model_params = json.loads(readfile(args.params))
    ai = RainfallWaterSegmenter.from_checkpoint(args.checkpoint, **model_params)

    sys.stderr.write(f"\n\n>>> This is TensorFlow {tf.__version__}\n\n\n")
    # With a directory of input files the output order is NOT guaranteed to
    # match the input order (in fact, it probably won't).
    dataset = dataset_predict(
        dirpath_input=args.input,
        parallel_reads_multiplier=args.read_multiplier,
    )

    output_mode = MODE_PNG if args.output.endswith(".png") else MODE_JSONL
    logger.info("Output mode is " + ("PNG" if output_mode == MODE_PNG else "JSONL"))
    logger.info(f"Records per file: {args.records_per_file}")

    if output_mode == MODE_JSONL:
        do_jsonl(args, ai, dataset, model_params)
    else:
        do_png(args, ai, dataset, model_params)
    sys.stderr.write(">>> Complete\n")
2022-10-26 16:11:36 +00:00
def do_png(args, ai, dataset, model_params):
    """Write one segmentation plot (actual vs. predicted water) per dataset item.

    Args:
        args: parsed arguments. Uses ``output``, a ``.png`` path in which
            ``+d`` is replaced by the running item index. Also uses
            ``model_code``, which is passed through to ``segmentation_plot``.
        ai: model wrapper. ``ai.embed`` is called on a batch of rainfall
            tensors. Per the original shape comment, it presumably yields
            [width, height, class scores] per item.
        dataset: dataset consumed via ``batched_iterator`` as
            ``(rainfall, water)`` pairs.
        model_params: model hyperparameters; only ``batch_size`` is read.
    """
    dirpath_output = os.path.dirname(args.output)
    # dirname() is "" for a bare filename, and os.mkdir("") would raise.
    # makedirs also copes with nested directories.
    if dirpath_output:
        os.makedirs(dirpath_output, exist_ok=True)

    i = 0  # global item index across all batches; substituted for "+d"
    batches = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
    for rainfall, water in batches:
        water_predict_batch = ai.embed(rainfall)
        for water_predict, water_actual in zip(water_predict_batch, tf.unstack(water, axis=0)):
            # Collapse the trailing score axis to the most likely class per pixel.
            water_predict = tf.math.argmax(water_predict, axis=-1)
            water_actual = tf.squeeze(water_actual)
            segmentation_plot(
                water_actual, water_predict,
                args.model_code,
                args.output.replace("+d", str(i)),
            )
            i += 1
            if i % 100 == 0:
                sys.stderr.write(f"Processed {i} items\r")
2022-10-19 16:26:40 +00:00
2022-10-21 15:53:08 +00:00
def do_jsonl(args, ai, dataset, model_params):
    """Write predictions as newline-delimited JSON, one object per dataset item.

    Each line holds the subset of ``rainfall_actual``, ``water_actual`` and
    ``water_predict`` (as nested lists) selected by ``args.log``.

    If ``args.output`` is ``"-"``, records go to standard output. Otherwise,
    when ``args.records_per_file`` is positive, ``+d`` in ``args.output`` is
    replaced with 0, 1, 2, ... and each file holds at most that many records.
    Files are opened in text mode, ``"wt"`` for ``.gz`` paths and ``"w"``
    otherwise.

    Args:
        args: parsed arguments (``output``, ``records_per_file``, ``log``).
            ``log`` is expected to be a list of field names, as produced by
            ``run``.
        ai: model wrapper; ``ai.embed`` is called on each rainfall batch.
        dataset: dataset consumed via ``batched_iterator`` as
            ``(rainfall, water)`` pairs.
        model_params: model hyperparameters; only ``batch_size`` is read.
    """
    write_mode = "wt" if args.output.endswith(".gz") else "w"
    to_stdout = args.output == "-"
    # Splitting output only makes sense for real files; never close/reopen stdout.
    rotate = args.records_per_file > 0 and not to_stdout

    filepath_metadata = None
    if to_stdout:
        handle = sys.stdout
    else:
        handle = handle_open(
            args.output.replace("+d", str(0)) if rotate else args.output,
            write_mode,
        )
        filepath_metadata = os.path.join(os.path.dirname(args.output), "metadata.json")
    logger.info(f"filepath_output: {args.output}")
    logger.info(f"filepath_params: {filepath_metadata}")

    i = 0  # records written overall
    i_file = 0  # records written to the current output file
    files_done = 0  # completed output files; also the index of the next file
    try:
        batches = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
        for rainfall_actual_batch, water_actual_batch in batches:
            water_predict_batch = ai.embed(rainfall_actual_batch)
            items = zip(
                tf.unstack(rainfall_actual_batch, axis=0),
                tf.unstack(water_actual_batch, axis=0),
                water_predict_batch,
            )
            for rainfall_actual, water_actual, water_predict in items:
                # Collapse the trailing score axis to the most likely class per pixel.
                water_predict = tf.math.argmax(water_predict, axis=-1)
                water_actual = tf.squeeze(water_actual)

                # Rotate once the current file already holds records_per_file
                # records (">" here used to write one record too many), and
                # number the new file files_done so no index is skipped.
                if rotate and i_file >= args.records_per_file:
                    handle.close()
                    files_done += 1
                    i_file = 0
                    logger.info(f"PROGRESS:file {files_done}")
                    handle = handle_open(args.output.replace("+d", str(files_done)), write_mode)

                item_obj = {}
                if "rainfall_actual" in args.log:
                    item_obj["rainfall_actual"] = rainfall_actual.numpy().tolist()
                if "water_actual" in args.log:
                    item_obj["water_actual"] = water_actual.numpy().tolist()
                if "water_predict" in args.log:
                    item_obj["water_predict"] = water_predict.numpy().tolist()
                # Compact separators keep lines small. Ref https://stackoverflow.com/a/64710892/1460422
                handle.write(json.dumps(item_obj, separators=(",", ":")) + "\n")

                if i % 100 == 0:
                    sys.stderr.write(f"[pretrain:predict] STEP {i}\r")
                i += 1
                i_file += 1
    finally:
        # Close files even on error, but never close the process's stdout.
        if handle is sys.stdout:
            handle.flush()
        else:
            handle.close()