mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
dlr: address some ruff linting warnings
This commit is contained in:
parent
7e00ede747
commit
8c3ddbd86f
1 changed files with 23 additions and 22 deletions
|
@ -2,32 +2,32 @@
|
|||
# @source https://keras.io/examples/vision/deeplabv3_plus/
|
||||
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
|
||||
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
from lib.ai.helpers.summarywriter import summarywriter
|
||||
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||
|
||||
import os
|
||||
import io
|
||||
import math
|
||||
import json
|
||||
# import cv2 # optional import below in get_overlay
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from scipy.io import loadmat
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from loguru import logger
|
||||
from scipy.io import loadmat
|
||||
|
||||
# import cv2 # optional import below in get_overlay
|
||||
|
||||
import lib.primitives.env as env
|
||||
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
||||
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
|
||||
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
|
||||
from lib.ai.components.MetricSpecificity import specificity
|
||||
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
|
||||
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
|
||||
from lib.ai.helpers.summarywriter import summarywriter
|
||||
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
||||
|
||||
# from glob import glob
|
||||
|
||||
time_start = datetime.now()
|
||||
logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||
|
@ -166,7 +166,7 @@ if PATH_CHECKPOINT is None:
|
|||
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
|
||||
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||
else:
|
||||
logger.info(f"[DeepLabV3+] Upsample disabled")
|
||||
logger.info("[DeepLabV3+] Upsample disabled")
|
||||
x = model_input
|
||||
|
||||
match backbone:
|
||||
|
@ -279,6 +279,7 @@ if PATH_CHECKPOINT is None:
|
|||
),
|
||||
],
|
||||
steps_per_epoch=STEPS_PER_EPOCH,
|
||||
# use_multiprocessing=True # commented out but could be a good idea to squash warning? alt increase batch size..... but that uses more memory >_<
|
||||
)
|
||||
logger.info(">>> Training complete")
|
||||
logger.info(">>> Plotting graphs")
|
||||
|
@ -315,11 +316,11 @@ def decode_segmentation_masks(mask, colormap, n_classes):
|
|||
r = np.zeros_like(mask).astype(np.uint8)
|
||||
g = np.zeros_like(mask).astype(np.uint8)
|
||||
b = np.zeros_like(mask).astype(np.uint8)
|
||||
for l in range(0, n_classes):
|
||||
idx = mask == l
|
||||
r[idx] = colormap[l, 0]
|
||||
g[idx] = colormap[l, 1]
|
||||
b[idx] = colormap[l, 2]
|
||||
for label_index in range(0, n_classes):
|
||||
idx = mask == label_index
|
||||
r[idx] = colormap[label_index, 0]
|
||||
g[idx] = colormap[label_index, 1]
|
||||
b[idx] = colormap[label_index, 2]
|
||||
rgb = np.stack([r, g, b], axis=2)
|
||||
return rgb
|
||||
|
||||
|
|
Loading…
Reference in a new issue