dlr: address some ruff linting warnings

This commit is contained in:
Starbeamrainbowlabs 2024-12-19 15:21:36 +00:00
parent 7e00ede747
commit 8c3ddbd86f
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -2,32 +2,32 @@
# @source https://keras.io/examples/vision/deeplabv3_plus/ # @source https://keras.io/examples/vision/deeplabv3_plus/
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip] # Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import io import io
import math
import json import json
# import cv2 # optional import below in get_overlay import math
import numpy as np import os
from glob import glob from datetime import datetime
from scipy.io import loadmat
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf import tensorflow as tf
from loguru import logger
from scipy.io import loadmat
# import cv2 # optional import below in get_overlay
import lib.primitives.env as env import lib.primitives.env as env
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
# from glob import glob
time_start = datetime.now() time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}") logger.info(f"Starting at {str(datetime.now().isoformat())}")
@ -166,7 +166,7 @@ if PATH_CHECKPOINT is None:
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x") logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
x = tf.keras.layers.UpSampling2D(size=2)(model_input) x = tf.keras.layers.UpSampling2D(size=2)(model_input)
else: else:
logger.info(f"[DeepLabV3+] Upsample disabled") logger.info("[DeepLabV3+] Upsample disabled")
x = model_input x = model_input
match backbone: match backbone:
@ -279,6 +279,7 @@ if PATH_CHECKPOINT is None:
), ),
], ],
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
# use_multiprocessing=True # commented out but could be a good idea to squash warning? alt increase batch size..... but that uses more memory >_<
) )
logger.info(">>> Training complete") logger.info(">>> Training complete")
logger.info(">>> Plotting graphs") logger.info(">>> Plotting graphs")
@ -315,11 +316,11 @@ def decode_segmentation_masks(mask, colormap, n_classes):
r = np.zeros_like(mask).astype(np.uint8) r = np.zeros_like(mask).astype(np.uint8)
g = np.zeros_like(mask).astype(np.uint8) g = np.zeros_like(mask).astype(np.uint8)
b = np.zeros_like(mask).astype(np.uint8) b = np.zeros_like(mask).astype(np.uint8)
for l in range(0, n_classes): for label_index in range(0, n_classes):
idx = mask == l idx = mask == label_index
r[idx] = colormap[l, 0] r[idx] = colormap[label_index, 0]
g[idx] = colormap[l, 1] g[idx] = colormap[label_index, 1]
b[idx] = colormap[l, 2] b[idx] = colormap[label_index, 2]
rgb = np.stack([r, g, b], axis=2) rgb = np.stack([r, g, b], axis=2)
return rgb return rgb