dlr: address some ruff linting warnings

This commit is contained in:
Starbeamrainbowlabs 2024-12-19 15:21:36 +00:00
parent 7e00ede747
commit 8c3ddbd86f
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -2,32 +2,32 @@
# @source https://keras.io/examples/vision/deeplabv3_plus/
# Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
from datetime import datetime
from loguru import logger
from lib.ai.helpers.summarywriter import summarywriter
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
import os
import io
import math
import json
# import cv2 # optional import below in get_overlay
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
import math
import os
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from loguru import logger
from scipy.io import loadmat
# import cv2 # optional import below in get_overlay
import lib.primitives.env as env
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
from lib.ai.components.MetricSpecificity import specificity
from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
from lib.ai.helpers.summarywriter import summarywriter
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
# from glob import glob
time_start = datetime.now()
logger.info(f"Starting at {str(datetime.now().isoformat())}")
@ -166,7 +166,7 @@ if PATH_CHECKPOINT is None:
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
else:
logger.info(f"[DeepLabV3+] Upsample disabled")
logger.info("[DeepLabV3+] Upsample disabled")
x = model_input
match backbone:
@ -279,6 +279,7 @@ if PATH_CHECKPOINT is None:
),
],
steps_per_epoch=STEPS_PER_EPOCH,
# use_multiprocessing=True # commented out but could be a good idea to squash warning? alt increase batch size..... but that uses more memory >_<
)
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
@ -315,11 +316,11 @@ def decode_segmentation_masks(mask, colormap, n_classes):
r = np.zeros_like(mask).astype(np.uint8)
g = np.zeros_like(mask).astype(np.uint8)
b = np.zeros_like(mask).astype(np.uint8)
for l in range(0, n_classes):
idx = mask == l
r[idx] = colormap[l, 0]
g[idx] = colormap[l, 1]
b[idx] = colormap[l, 2]
for label_index in range(0, n_classes):
idx = mask == label_index
r[idx] = colormap[label_index, 0]
g[idx] = colormap[label_index, 1]
b[idx] = colormap[label_index, 2]
rgb = np.stack([r, g, b], axis=2)
return rgb