mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
deeplabv3+ for rainfall
This commit is contained in:
parent
677e39f820
commit
c17e53ca75
1 changed files with 42 additions and 20 deletions
|
@ -15,24 +15,27 @@ import matplotlib.pyplot as plt
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_SIZE = 128 # was 512; 128 is the highest power of 2 that fits the data
|
||||
BATCH_SIZE = int(os.environ["DL_BATCH_SIZE"]) if "DL_BATCH_SIZE" in os.environ else 64
|
||||
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
|
||||
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||
NUM_CLASSES = 2
|
||||
DIR_DATA_TF = os.environ["DL_DATA_DIR_TF"]
|
||||
PATH_HEIGHTMAP = os.environ["DL_PATH_HEIGHTMAP"]
|
||||
NUM_BATCHES = int(os.environ["DL_NUM_BATCHES"] if "DL_NUM_BATCHES" in os.environ else "0")
|
||||
DIR_DATA_TF = os.environ["DATA_DIR_TF"]
|
||||
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
|
||||
PATH_COLOURMAP = os.environ["COLOURMAP"]
|
||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||
|
||||
DIR_OUTPUT=f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||
|
||||
if not os.path.exists(DIR_OUTPUT):
|
||||
os.makedirs(DIR_OUTPUT)
|
||||
|
||||
logger.info("DeepLabv3+ rainfall radar TEST")
|
||||
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
|
||||
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
||||
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
|
||||
logger.info(f"> DL_PATH_HEIGHTMAP {DL_PATH_HEIGHTMAP}")
|
||||
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
|
||||
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
|
||||
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
|
||||
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
|
||||
|
||||
|
||||
dataset_train, dataset_validate = dataset_mono(
|
||||
|
@ -147,6 +150,7 @@ history = model.fit(train_dataset,
|
|||
separator="\t"
|
||||
)
|
||||
],
|
||||
steps_per_epoch=STEPS_PER_EPOCH,
|
||||
)
|
||||
logger.info(">>> Training complete")
|
||||
logger.info(">>> Plotting graphs")
|
||||
|
@ -189,16 +193,16 @@ plt.close()
|
|||
|
||||
# Loading the Colormap
|
||||
colormap = loadmat(
|
||||
os.path.join(os.path.dirname(DATA_DIR), "human_colormap.mat")
|
||||
PATH_COLOURMAP
|
||||
)["colormap"]
|
||||
colormap = colormap * 100
|
||||
colormap = colormap.astype(np.uint8)
|
||||
|
||||
|
||||
def infer(model, image_tensor):
|
||||
predictions = model.predict(np.expand_dims((image_tensor), axis=0))
|
||||
predictions = np.squeeze(predictions)
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
|
||||
predictions = tf.squeeze(predictions)
|
||||
predictions = tf.argmax(predictions, axis=2)
|
||||
return predictions
|
||||
|
||||
|
||||
|
@ -232,18 +236,36 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
|||
plt.savefig(filepath)
|
||||
|
||||
|
||||
def plot_predictions(filepath, images_list, colormap, model):
|
||||
for image_file in images_list:
|
||||
image_tensor = read_image(image_file)
|
||||
prediction_mask = infer(image_tensor=image_tensor, model=model)
|
||||
def plot_predictions(filepath, input_items, colormap, model):
|
||||
for input_tensor in input_items:
|
||||
prediction_mask = infer(image_tensor=input_tensor, model=model)
|
||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
|
||||
overlay = get_overlay(image_tensor, prediction_colormap)
|
||||
overlay = get_overlay(input_tensor, prediction_colormap)
|
||||
plot_samples_matplotlib(
|
||||
filepath,
|
||||
[image_tensor, overlay, prediction_colormap],
|
||||
[input_tensor, overlay, prediction_colormap],
|
||||
figsize=(18, 14)
|
||||
)
|
||||
|
||||
def get_items_from_batched(dataset, count):
|
||||
result = []
|
||||
for batched in dataset:
|
||||
items = tf.unstack(batched, axis=0)
|
||||
for item in items:
|
||||
result.append(item)
|
||||
if len(result) >= count:
|
||||
return result
|
||||
|
||||
plot_predictions(os.path.join(DIR_OUTPUT, "predict_train.png"), train_images[:4], colormap, model=model)
|
||||
plot_predictions(os.path.join(DIR_OUTPUT, "predict_validate.png"), val_images[:4], colormap, model=model)
|
||||
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||
get_items_from_batched(dataset_train, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||
get_items_from_batched(dataset_validate, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue