From c17e53ca759af70705eccaf75fffa04fb6d17695 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Fri, 16 Dec 2022 19:52:59 +0000 Subject: [PATCH] deeplabv3+ for rainfall --- aimodel/src/deeplabv3_plus_test_rainfall.py | 62 ++++++++++++++------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 71f4e76..b058377 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -15,24 +15,27 @@ import matplotlib.pyplot as plt import tensorflow as tf -IMAGE_SIZE = 128 # was 512; 128 is the highest power of 2 that fits the data -BATCH_SIZE = int(os.environ["DL_BATCH_SIZE"]) if "DL_BATCH_SIZE" in os.environ else 64 +IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data +BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64 NUM_CLASSES = 2 -DIR_DATA_TF = os.environ["DL_DATA_DIR_TF"] -PATH_HEIGHTMAP = os.environ["DL_PATH_HEIGHTMAP"] -NUM_BATCHES = int(os.environ["DL_NUM_BATCHES"] if "DL_NUM_BATCHES" in os.environ else "0") +DIR_DATA_TF = os.environ["DATA_DIR_TF"] +PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] +PATH_COLOURMAP = os.environ["COLOURMAP"] +STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None -DIR_OUTPUT=f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" +DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" if not os.path.exists(DIR_OUTPUT): os.makedirs(DIR_OUTPUT) -logger.info("DeepLabv3+ rainfall radar TEST") +logger.info("DeepLabV3+ rainfall radar TEST") logger.info(f"> NUM_BATCHES {NUM_BATCHES}") logger.info(f"> BATCH_SIZE {BATCH_SIZE}") logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}") -logger.info(f"> DL_PATH_HEIGHTMAP {DL_PATH_HEIGHTMAP}") +logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}") +logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}") logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}") +logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}") dataset_train, dataset_validate = dataset_mono( @@ -147,6 +150,7 @@ history = model.fit(train_dataset, separator="\t" ) ], + steps_per_epoch=STEPS_PER_EPOCH, ) logger.info(">>> Training complete") logger.info(">>> Plotting graphs") @@ -189,16 +193,16 @@ plt.close() # Loading the Colormap colormap = loadmat( - os.path.join(os.path.dirname(DATA_DIR), "human_colormap.mat") + PATH_COLOURMAP )["colormap"] colormap = colormap * 100 colormap = colormap.astype(np.uint8) def infer(model, image_tensor): - predictions = model.predict(np.expand_dims((image_tensor), axis=0)) - predictions = np.squeeze(predictions) - predictions = np.argmax(predictions, axis=2) + predictions = model.predict(tf.expand_dims((image_tensor), axis=0)) + predictions = tf.squeeze(predictions) + predictions = tf.argmax(predictions, axis=2) return predictions @@ -232,18 +236,36 @@ def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)): plt.savefig(filepath) -def plot_predictions(filepath, images_list, colormap, model): - for image_file in images_list: - image_tensor = read_image(image_file) - prediction_mask = infer(image_tensor=image_tensor, model=model) +def plot_predictions(filepath, input_items, colormap, model): + for input_tensor in input_items: + prediction_mask = infer(image_tensor=input_tensor, model=model) prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20) - overlay = get_overlay(image_tensor, prediction_colormap) + overlay = get_overlay(input_tensor, prediction_colormap) plot_samples_matplotlib( filepath, - [image_tensor, overlay, prediction_colormap], + [input_tensor, overlay, prediction_colormap], figsize=(18, 14) ) +def get_items_from_batched(dataset, count): + result = [] + for batched in dataset: + items = tf.unstack(batched, axis=0) + for item in items: + result.append(item) + if len(result) >= count: + return result -plot_predictions(os.path.join(DIR_OUTPUT, "predict_train.png"), train_images[:4], colormap, model=model) -plot_predictions(os.path.join(DIR_OUTPUT, "predict_validate.png"), val_images[:4], colormap, model=model) + +plot_predictions( + os.path.join(DIR_OUTPUT, "predict_train.png"), + get_items_from_batched(dataset_train, 4), + colormap, + model=model +) +plot_predictions( + os.path.join(DIR_OUTPUT, "predict_validate.png"), + get_items_from_batched(dataset_validate, 4), + colormap, + model=model +)