deeplabv3+ for rainfall

This commit is contained in:
Starbeamrainbowlabs 2022-12-16 19:52:59 +00:00
parent 677e39f820
commit c17e53ca75
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -15,24 +15,27 @@ import matplotlib.pyplot as plt
import tensorflow as tf
IMAGE_SIZE = 128 # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ["DL_BATCH_SIZE"]) if "DL_BATCH_SIZE" in os.environ else 64
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
NUM_CLASSES = 2
DIR_DATA_TF = os.environ["DL_DATA_DIR_TF"]
PATH_HEIGHTMAP = os.environ["DL_PATH_HEIGHTMAP"]
NUM_BATCHES = int(os.environ["DL_NUM_BATCHES"] if "DL_NUM_BATCHES" in os.environ else "0")
DIR_DATA_TF = os.environ["DATA_DIR_TF"]
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
PATH_COLOURMAP = os.environ["COLOURMAP"]
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
DIR_OUTPUT=f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
if not os.path.exists(DIR_OUTPUT):
os.makedirs(DIR_OUTPUT)
logger.info("DeepLabv3+ rainfall radar TEST")
logger.info("DeepLabV3+ rainfall radar TEST")
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
logger.info(f"> DL_PATH_HEIGHTMAP {DL_PATH_HEIGHTMAP}")
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
dataset_train, dataset_validate = dataset_mono(
@ -147,6 +150,7 @@ history = model.fit(train_dataset,
separator="\t"
)
],
steps_per_epoch=STEPS_PER_EPOCH,
)
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
@ -189,16 +193,16 @@ plt.close()
# Loading the Colormap
colormap = loadmat(
os.path.join(os.path.dirname(DATA_DIR), "human_colormap.mat")
PATH_COLOURMAP
)["colormap"]
colormap = colormap * 100
colormap = colormap.astype(np.uint8)
def infer(model, image_tensor):
predictions = model.predict(np.expand_dims((image_tensor), axis=0))
predictions = np.squeeze(predictions)
predictions = np.argmax(predictions, axis=2)
predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
predictions = tf.squeeze(predictions)
predictions = tf.argmax(predictions, axis=2)
return predictions
@ -232,18 +236,36 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
plt.savefig(filepath)
def plot_predictions(filepath, images_list, colormap, model):
for image_file in images_list:
image_tensor = read_image(image_file)
prediction_mask = infer(image_tensor=image_tensor, model=model)
def plot_predictions(filepath, input_items, colormap, model):
for input_tensor in input_items:
prediction_mask = infer(image_tensor=input_tensor, model=model)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
overlay = get_overlay(image_tensor, prediction_colormap)
overlay = get_overlay(input_tensor, prediction_colormap)
plot_samples_matplotlib(
filepath,
[image_tensor, overlay, prediction_colormap],
[input_tensor, overlay, prediction_colormap],
figsize=(18, 14)
)
def get_items_from_batched(dataset, count):
result = []
for batched in dataset:
items = tf.unstack(batched, axis=0)
for item in items:
result.append(item)
if len(result) >= count:
return result
plot_predictions(os.path.join(DIR_OUTPUT, "predict_train.png"), train_images[:4], colormap, model=model)
plot_predictions(os.path.join(DIR_OUTPUT, "predict_validate.png"), val_images[:4], colormap, model=model)
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_train.png"),
get_items_from_batched(dataset_train, 4),
colormap,
model=model
)
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_validate.png"),
get_items_from_batched(dataset_validate, 4),
colormap,
model=model
)