mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
dlr: address some ruff linting warnings
This commit is contained in:
parent
7e00ede747
commit
8c3ddbd86f
1 changed files with 23 additions and 22 deletions
|
@ -2,32 +2,32 @@
|
||||||
# @source https://keras.io/examples/vision/deeplabv3_plus/
|
# @source https://keras.io/examples/vision/deeplabv3_plus/
|
||||||
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
|
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from lib.ai.helpers.summarywriter import summarywriter
|
|
||||||
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import json
|
import json
|
||||||
# import cv2 # optional import below in get_overlay
|
import math
|
||||||
import numpy as np
|
import os
|
||||||
from glob import glob
|
from datetime import datetime
|
||||||
from scipy.io import loadmat
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from loguru import logger
|
||||||
|
from scipy.io import loadmat
|
||||||
|
|
||||||
|
# import cv2 # optional import below in get_overlay
|
||||||
|
|
||||||
import lib.primitives.env as env
|
import lib.primitives.env as env
|
||||||
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||||
|
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
||||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||||
|
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||||
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
||||||
from lib.ai.components.MetricSpecificity import specificity
|
from lib.ai.components.MetricSpecificity import specificity
|
||||||
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
from lib.ai.helpers.summarywriter import summarywriter
|
||||||
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
||||||
|
|
||||||
|
# from glob import glob
|
||||||
|
|
||||||
time_start = datetime.now()
|
time_start = datetime.now()
|
||||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||||
|
@ -166,7 +166,7 @@ if PATH_CHECKPOINT is None:
|
||||||
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
|
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
|
||||||
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||||
else:
|
else:
|
||||||
logger.info(f"[DeepLabV3+] Upsample disabled")
|
logger.info("[DeepLabV3+] Upsample disabled")
|
||||||
x = model_input
|
x = model_input
|
||||||
|
|
||||||
match backbone:
|
match backbone:
|
||||||
|
@ -279,6 +279,7 @@ if PATH_CHECKPOINT is None:
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
steps_per_epoch=STEPS_PER_EPOCH,
|
steps_per_epoch=STEPS_PER_EPOCH,
|
||||||
|
# use_multiprocessing=True # commented out but could be a good idea to squash warning? alt increase batch size..... but that uses more memory >_<
|
||||||
)
|
)
|
||||||
logger.info(">>> Training complete")
|
logger.info(">>> Training complete")
|
||||||
logger.info(">>> Plotting graphs")
|
logger.info(">>> Plotting graphs")
|
||||||
|
@ -315,11 +316,11 @@ def decode_segmentation_masks(mask, colormap, n_classes):
|
||||||
r = np.zeros_like(mask).astype(np.uint8)
|
r = np.zeros_like(mask).astype(np.uint8)
|
||||||
g = np.zeros_like(mask).astype(np.uint8)
|
g = np.zeros_like(mask).astype(np.uint8)
|
||||||
b = np.zeros_like(mask).astype(np.uint8)
|
b = np.zeros_like(mask).astype(np.uint8)
|
||||||
for l in range(0, n_classes):
|
for label_index in range(0, n_classes):
|
||||||
idx = mask == l
|
idx = mask == label_index
|
||||||
r[idx] = colormap[l, 0]
|
r[idx] = colormap[label_index, 0]
|
||||||
g[idx] = colormap[l, 1]
|
g[idx] = colormap[label_index, 1]
|
||||||
b[idx] = colormap[l, 2]
|
b[idx] = colormap[label_index, 2]
|
||||||
rgb = np.stack([r, g, b], axis=2)
|
rgb = np.stack([r, g, b], axis=2)
|
||||||
return rgb
|
return rgb
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue