This commit is contained in:
Starbeamrainbowlabs 2023-01-05 17:09:09 +00:00
parent 11ccd4cbee
commit dd79fb6e68
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -18,7 +18,7 @@ import tensorflow as tf
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
NUM_CLASSES = 2
DIR_DATA_TF = os.environ["DATA_DIR_TF"]
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
PATH_COLOURMAP = os.environ["COLOURMAP"]
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
@ -29,17 +29,16 @@ if not os.path.exists(DIR_OUTPUT):
os.makedirs(DIR_OUTPUT)
logger.info("DeepLabV3+ rainfall radar TEST")
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
logger.info(f"> DIR_RAINFALLWATER {DIR_RAINFALLWATER}")
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
dataset_train, dataset_validate = dataset_mono(
dirpath_input=DIR_DATA,
dirpath_input=DIR_RAINFALLWATER,
batch_size=BATCH_SIZE,
water_threshold=0.1,
rainfall_scale_up=2, # done BEFORE cropping to the below size
@ -141,8 +140,8 @@ model.compile(
metrics=["accuracy"],
)
logger.info(">>> Beginning training")
history = model.fit(train_dataset,
validation_data=val_dataset,
history = model.fit(dataset_train,
validation_data=dataset_validate,
epochs=25,
callbacks=[
tf.keras.callbacks.CSVLogger(
@ -219,10 +218,10 @@ def decode_segmentation_masks(mask, colormap, n_classes):
return rgb
def get_overlay(image, colored_mask):
def get_overlay(image, coloured_mask):
image = tf.keras.preprocessing.image.array_to_img(image)
image = np.array(image).astype(np.uint8)
overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0)
return overlay