mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
fixup
This commit is contained in:
parent
11ccd4cbee
commit
dd79fb6e68
1 changed files with 8 additions and 9 deletions
|
@ -18,7 +18,7 @@ import tensorflow as tf
|
|||
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
|
||||
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||
NUM_CLASSES = 2
|
||||
DIR_DATA_TF = os.environ["DATA_DIR_TF"]
|
||||
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
|
||||
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
|
||||
PATH_COLOURMAP = os.environ["COLOURMAP"]
|
||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||
|
@ -29,17 +29,16 @@ if not os.path.exists(DIR_OUTPUT):
|
|||
os.makedirs(DIR_OUTPUT)
|
||||
|
||||
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||
logger.info(f"> NUM_BATCHES {NUM_BATCHES}")
|
||||
logger.info(f"> BATCH_SIZE {BATCH_SIZE}")
|
||||
logger.info(f"> DIR_DATA_TF {DIR_DATA_TF}")
|
||||
logger.info(f"> DIR_RAINFALLWATER {DIR_RAINFALLWATER}")
|
||||
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
|
||||
logger.info(f"> PATH_COLOURMAP {PATH_COLOURMAP}")
|
||||
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
|
||||
logger.info(f"> STEPS_PER_EPOCH {STEPS_PER_EPOCH}")
|
||||
logger.info(f"> DIR_OUTPUT {DIR_OUTPUT}")
|
||||
|
||||
|
||||
dataset_train, dataset_validate = dataset_mono(
|
||||
dirpath_input=DIR_DATA,
|
||||
dirpath_input=DIR_RAINFALLWATER,
|
||||
batch_size=BATCH_SIZE,
|
||||
water_threshold=0.1,
|
||||
rainfall_scale_up=2, # done BEFORE cropping to the below size
|
||||
|
@ -141,8 +140,8 @@ model.compile(
|
|||
metrics=["accuracy"],
|
||||
)
|
||||
logger.info(">>> Beginning training")
|
||||
history = model.fit(train_dataset,
|
||||
validation_data=val_dataset,
|
||||
history = model.fit(dataset_train,
|
||||
validation_data=dataset_validate,
|
||||
epochs=25,
|
||||
callbacks=[
|
||||
tf.keras.callbacks.CSVLogger(
|
||||
|
@ -219,10 +218,10 @@ def decode_segmentation_masks(mask, colormap, n_classes):
|
|||
return rgb
|
||||
|
||||
|
||||
def get_overlay(image, colored_mask):
|
||||
def get_overlay(image, coloured_mask):
|
||||
image = tf.keras.preprocessing.image.array_to_img(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
|
||||
overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0)
|
||||
return overlay
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue