mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
deeplabv3+ for rainfall
This commit is contained in:
parent
677e39f820
commit
c17e53ca75
1 changed files with 42 additions and 20 deletions
|
@ -15,24 +15,27 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
IMAGE_SIZE = 128 # was 512; 128 is the highest power of 2 that fits the data
|
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
|
||||||
BATCH_SIZE = int(os.environ["DL_BATCH_SIZE"]) if "DL_BATCH_SIZE" in os.environ else 64
|
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||||
NUM_CLASSES = 2
|
NUM_CLASSES = 2
|
||||||
DIR_DATA_TF = os.environ["DL_DATA_DIR_TF"]
|
DIR_DATA_TF = os.environ["DATA_DIR_TF"]
|
||||||
PATH_HEIGHTMAP = os.environ["DL_PATH_HEIGHTMAP"]
|
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
|
||||||
NUM_BATCHES = int(os.environ["DL_NUM_BATCHES"] if "DL_NUM_BATCHES" in os.environ else "0")
|
PATH_COLOURMAP = os.environ["COLOURMAP"]
|
||||||
|
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||||
|
|
||||||
DIR_OUTPUT=f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||||
|
|
||||||
if not os.path.exists(DIR_OUTPUT):
|
if not os.path.exists(DIR_OUTPUT):
|
||||||
os.makedirs(DIR_OUTPUT)
|
os.makedirs(DIR_OUTPUT)
|
||||||
|
|
||||||
logger.info("DeepLabv3+ rainfall radar TEST")
|
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||||
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
|
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
|
||||||
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
||||||
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
|
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
|
||||||
logger.info(f"> DL_PATH_HEIGHTMAP {DL_PATH_HEIGHTMAP}")
|
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
|
||||||
|
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
|
||||||
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
|
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
|
||||||
|
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
|
||||||
|
|
||||||
|
|
||||||
dataset_train, dataset_validate = dataset_mono(
|
dataset_train, dataset_validate = dataset_mono(
|
||||||
|
@ -147,6 +150,7 @@ history = model.fit(train_dataset,
|
||||||
separator="\t"
|
separator="\t"
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
steps_per_epoch=STEPS_PER_EPOCH,
|
||||||
)
|
)
|
||||||
logger.info(">>> Training complete")
|
logger.info(">>> Training complete")
|
||||||
logger.info(">>> Plotting graphs")
|
logger.info(">>> Plotting graphs")
|
||||||
|
@ -189,16 +193,16 @@ plt.close()
|
||||||
|
|
||||||
# Loading the Colormap
|
# Loading the Colormap
|
||||||
colormap = loadmat(
|
colormap = loadmat(
|
||||||
os.path.join(os.path.dirname(DATA_DIR), "human_colormap.mat")
|
PATH_COLOURMAP
|
||||||
)["colormap"]
|
)["colormap"]
|
||||||
colormap = colormap * 100
|
colormap = colormap * 100
|
||||||
colormap = colormap.astype(np.uint8)
|
colormap = colormap.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
def infer(model, image_tensor):
|
def infer(model, image_tensor):
|
||||||
predictions = model.predict(np.expand_dims((image_tensor), axis=0))
|
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
|
||||||
predictions = np.squeeze(predictions)
|
predictions = tf.squeeze(predictions)
|
||||||
predictions = np.argmax(predictions, axis=2)
|
predictions = tf.argmax(predictions, axis=2)
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,18 +236,36 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
|
||||||
plt.savefig(filepath)
|
plt.savefig(filepath)
|
||||||
|
|
||||||
|
|
||||||
def plot_predictions(filepath, images_list, colormap, model):
|
def plot_predictions(filepath, input_items, colormap, model):
|
||||||
for image_file in images_list:
|
for input_tensor in input_items:
|
||||||
image_tensor = read_image(image_file)
|
prediction_mask = infer(image_tensor=input_tensor, model=model)
|
||||||
prediction_mask = infer(image_tensor=image_tensor, model=model)
|
|
||||||
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
|
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
|
||||||
overlay = get_overlay(image_tensor, prediction_colormap)
|
overlay = get_overlay(input_tensor, prediction_colormap)
|
||||||
plot_samples_matplotlib(
|
plot_samples_matplotlib(
|
||||||
filepath,
|
filepath,
|
||||||
[image_tensor, overlay, prediction_colormap],
|
[input_tensor, overlay, prediction_colormap],
|
||||||
figsize=(18, 14)
|
figsize=(18, 14)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_items_from_batched(dataset, count):
|
||||||
|
result = []
|
||||||
|
for batched in dataset:
|
||||||
|
items = tf.unstack(batched, axis=0)
|
||||||
|
for item in items:
|
||||||
|
result.append(item)
|
||||||
|
if len(result) >= count:
|
||||||
|
return result
|
||||||
|
|
||||||
plot_predictions(os.path.join(DIR_OUTPUT, "predict_train.png"), train_images[:4], colormap, model=model)
|
|
||||||
plot_predictions(os.path.join(DIR_OUTPUT, "predict_validate.png"), val_images[:4], colormap, model=model)
|
plot_predictions(
|
||||||
|
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||||
|
get_items_from_batched(dataset_train, 4),
|
||||||
|
colormap,
|
||||||
|
model=model
|
||||||
|
)
|
||||||
|
plot_predictions(
|
||||||
|
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||||
|
get_items_from_batched(dataset_validate, 4),
|
||||||
|
colormap,
|
||||||
|
model=model
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue