From 8c3ddbd86fc78f40a604e215083a513fb726786b Mon Sep 17 00:00:00 2001
From: Starbeamrainbowlabs
Date: Thu, 19 Dec 2024 15:21:36 +0000
Subject: [PATCH] dlr: address some ruff linting warnings

---
 aimodel/src/deeplabv3_plus_test_rainfall.py | 45 +++++++++++----------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py
index 5ba6640..560e20d 100755
--- a/aimodel/src/deeplabv3_plus_test_rainfall.py
+++ b/aimodel/src/deeplabv3_plus_test_rainfall.py
@@ -2,32 +2,32 @@
 # @source https://keras.io/examples/vision/deeplabv3_plus/
 # Required dataset: https://drive.google.com/uc?id=1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz [instance-level-human-parsing.zip]
 
-from datetime import datetime
-from loguru import logger
-
-from lib.ai.helpers.summarywriter import summarywriter
-from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
-
-import os
 import io
-import math
 import json
-# import cv2 # optional import below in get_overlay
-import numpy as np
-from glob import glob
-from scipy.io import loadmat
-import matplotlib.pyplot as plt
+import math
+import os
+from datetime import datetime
 
+import matplotlib.pyplot as plt
+import numpy as np
 import tensorflow as tf
+from loguru import logger
+from scipy.io import loadmat
+
+# import cv2 # optional import below in get_overlay
 
 import lib.primitives.env as env
-from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
+from lib.ai.components.CallbackCustomModelCheckpoint import CallbackCustomModelCheckpoint
+from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
 from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
 from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
+from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
 from lib.ai.components.MetricSensitivity import make_sensitivity as sensitivity
 from lib.ai.components.MetricSpecificity import specificity
-from lib.ai.components.MetricMeanIoU import make_one_hot_mean_iou as mean_iou
-from lib.ai.components.CallbackExtraValidation import CallbackExtraValidation
+from lib.ai.helpers.summarywriter import summarywriter
+from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
+
+# from glob import glob
 
 time_start = datetime.now()
 logger.info(f"Starting at {str(datetime.now().isoformat())}")
@@ -166,7 +166,7 @@ if PATH_CHECKPOINT is None:
 		logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
 		x = tf.keras.layers.UpSampling2D(size=2)(model_input)
 	else:
-		logger.info(f"[DeepLabV3+] Upsample disabled")
+		logger.info("[DeepLabV3+] Upsample disabled")
 		x = model_input
 	
 	match backbone:
@@ -279,6 +279,7 @@ if PATH_CHECKPOINT is None:
 			),
 		],
 		steps_per_epoch=STEPS_PER_EPOCH,
+		# use_multiprocessing=True # commented out but could be a good idea to squash warning? alt increase batch size..... but that uses more memory >_<
 	)
 	logger.info(">>> Training complete")
 	logger.info(">>> Plotting graphs")
@@ -315,11 +316,11 @@ def decode_segmentation_masks(mask, colormap, n_classes):
 	r = np.zeros_like(mask).astype(np.uint8)
 	g = np.zeros_like(mask).astype(np.uint8)
 	b = np.zeros_like(mask).astype(np.uint8)
-	for l in range(0, n_classes):
-		idx = mask == l
-		r[idx] = colormap[l, 0]
-		g[idx] = colormap[l, 1]
-		b[idx] = colormap[l, 2]
+	for label_index in range(0, n_classes):
+		idx = mask == label_index
+		r[idx] = colormap[label_index, 0]
+		g[idx] = colormap[label_index, 1]
+		b[idx] = colormap[label_index, 2]
 	rgb = np.stack([r, g, b], axis=2)
 	return rgb
 
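A note on the commented-out use_multiprocessing=True idea in the model.fit() hunk above: in TF 2.x that flag is only honoured for Python-generator / keras.utils.Sequence inputs and has no effect on tf.data.Dataset inputs, which is what dataset_mono() appears to provide here. If the aim is to quiet an input-pipeline-bottleneck warning without raising the batch size (and hence memory), the usual tf.data route is to parallelise and prefetch the pipeline itself. A minimal sketch, using an invented stand-in dataset rather than the real dataset_mono() output:

import tensorflow as tf

# Hypothetical stand-in for the dataset_mono() pipeline; shapes are invented.
dataset_train = tf.data.Dataset.from_tensor_slices(
	(tf.random.normal([32, 128, 128, 8]), tf.zeros([32, 128, 128, 2]))
)

# Parallelise per-record work and overlap preprocessing with training,
# instead of (or as well as) increasing the batch size.
dataset_train = (
	dataset_train
	.map(lambda x, y: (x, y), num_parallel_calls=tf.data.AUTOTUNE)  # real preprocessing would go here
	.batch(8)
	.prefetch(tf.data.AUTOTUNE)
)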
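On the last hunk: the l → label_index rename in decode_segmentation_masks() fixes ruff's E741 (ambiguous variable name) without changing behaviour. For reference, a self-contained usage sketch of the function as it reads after the patch; the 3-class colormap and the random mask are illustrative only, not the values the real script loads:

import numpy as np

def decode_segmentation_masks(mask, colormap, n_classes):
	# Paint each integer class label in mask with its RGB colour from colormap.
	r = np.zeros_like(mask).astype(np.uint8)
	g = np.zeros_like(mask).astype(np.uint8)
	b = np.zeros_like(mask).astype(np.uint8)
	for label_index in range(0, n_classes):
		idx = mask == label_index
		r[idx] = colormap[label_index, 0]
		g[idx] = colormap[label_index, 1]
		b[idx] = colormap[label_index, 2]
	rgb = np.stack([r, g, b], axis=2)
	return rgb

colormap = np.array([[0, 0, 0], [0, 114, 178], [230, 159, 0]])  # made-up colours
mask = np.random.randint(0, 3, size=(4, 4))  # fake 4x4 class-label prediction
print(decode_segmentation_masks(mask, colormap, 3).shape)  # -> (4, 4, 3)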