From 5d62e3cee859f50691fa328aed851583e164d197 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 29 Aug 2024 16:43:29 +0100 Subject: [PATCH] Implement env from PhD-smflooding-scene --- aimodel/src/deeplabv3_plus_test_rainfall.py | 53 +++---- aimodel/src/lib/primitives/env.py | 145 ++++++++++++++++++++ 2 files changed, 172 insertions(+), 26 deletions(-) create mode 100644 aimodel/src/lib/primitives/env.py diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 1c91637..cff279d 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -20,6 +20,7 @@ import matplotlib.pyplot as plt import tensorflow as tf +import lib.primitives.env from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient @@ -37,41 +38,41 @@ logger.info(f"Starting at {str(datetime.now().isoformat())}") # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ # ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██ -IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data -BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64 +IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data +BATCH_SIZE = env.read("BATCH_SIZE", int, 64) NUM_CLASSES = 2 -DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"] -PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] -PATH_COLOURMAP = os.environ["PATH_COLOURMAP"] -PARALLEL_READS = float(os.environ["PARALLEL_READS"]) if "PARALLEL_READS" in os.environ else 1.5 -STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None -REMOVE_ISOLATED_PIXELS = False if "NO_REMOVE_ISOLATED_PIXELS" in os.environ else True -EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 50 -LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice" # other possible valuesL cross-entropy -DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False -LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001 -WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1 -UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2 +DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str) +PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str) +PATH_COLOURMAP = env.read("PATH_COLOURMAP", str) +PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5) +STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None) +REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True) +EPOCHS = env.read("EPOCHS", int, 50) +LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy +DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False) +LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001) +WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1) +UPSAMPLE = env.read("UPSAMPLE", int, 2) +SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2) +SPLIT_TEST = env.read("SPLIT_TEST", float, 0) -STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else 1 -JIT_COMPILE = True if "JIT_COMPILE" in os.environ else False -DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" - -PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None -PREDICT_COUNT = int(os.environ["PREDICT_COUNT"]) if "PREDICT_COUNT" in os.environ else 25 -PREDICT_AS_ONE = True if "PREDICT_AS_ONE" in os.environ else False - +STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1) +JIT_COMPILE = env.read("JIT_COMPILE", bool, False) +DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST") +PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None) +PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25) +PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False) # ~~~ -if not os.path.exists(DIR_OUTPUT): - os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints")) +env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True) # ~~~ logger.info("DeepLabV3+ rainfall radar TEST") -for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]: - logger.info(f"> {env_name} {str(globals()[env_name])}") +env.print_all(False) +# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]: +# logger.info(f"> {env_name} {str(globals()[env_name])}") # ██████ █████ ████████ █████ ███████ ███████ ████████ diff --git a/aimodel/src/lib/primitives/env.py b/aimodel/src/lib/primitives/env.py new file mode 100644 index 0000000..888cf57 --- /dev/null +++ b/aimodel/src/lib/primitives/env.py @@ -0,0 +1,145 @@ +import os + +# Ref https://stackoverflow.com/a/61733714/1460422 + +### +## Environment parsing and validation helpers +## @sbrl, Licence: GPLv3 +### + +## Changelog: +# 2024-09-29: Create this changelog, prepare for reuse + +############################################################################## + +# Simple polyfill for Symbol from JS: https://devdocs.io/javascript/global_objects/symbol + +class Symbol: + def __init__(self, name=''): + self.name = f"Symbol({name})" + + def __repr__(self): + return self.name + +############################################################################## + + +SYM_RAISE_EXCEPTION = Symbol("__env_read_raise_exception") + +envs_read = [] + + +def read(name, type_class, default=SYM_RAISE_EXCEPTION): + """ + Reads, parses, and returns an environment variable with the specified name and type, with an optional default value. + + If the environment variable does not exist and no default value is provided, an `Exception` is raised. Otherwise, the environment variable value is converted to the specified type and returned. + + If `type_class == bool` and no `default` value is provided, then the default value is set to `False` and an `Exception` is **not** raised. + + Args: + name (str): The name of the environment variable to read. + type_class (type): The type to convert the environment variable value to. + default (Any, optional): The default value to use if the environment variable does not exist. Defaults to `SYM_RAISE_EXCEPTION`, which will raise an exception if the variable does not exist. + + Returns: + Any: The environment variable value converted to the specified type. + + Raises: + Exception: If the environment variable does not exist and no default value is provided. + """ + + if name not in os.environ: + if type_class == bool and default == SYM_RAISE_EXCEPTION: + default = False + if default == SYM_RAISE_EXCEPTION: + raise Exception(f"Error: Environment variable {name} does not exist") + envs_read.append([name, default, True]) + return default + + result = os.environ[name] + if type_class == bool: + result = False if default == True else True + else: + result = type_class(result) + + envs_read.append([name, result, False]) + return result + + +def print_all(table=True): + """ + Prints a formatted table of all environment variables that have been read so far. + + The table includes the name, type, value, and flags of each environment variable. The column widths are automatically adjusted to fit the longest values. + + If no environment variables have been read yet, a message is printed indicating that. + + Args: + table (bool): Whether to print in a table or not. Defaults to True. If False, then values are directly printed in a list instead of in a pretty table. + """ + + if not envs_read: + print("No environment variables have been read yet.") + return + + + # Calculate column widths + width_name = max(len("Name"), max(len(env[0]) for env in envs_read)) + width_type = max(len("Type"), max( + len(type(env[1]).__name__) for env in envs_read)) + width_value = max(len("Value"), max(len(str(env[1])) for env in envs_read)) + width_flags = max(len("Flags"), len("default")) + + + + if not table: + print("===================================") + for env in envs_read: + key, value, is_default = env + prefix = "* " if is_default else "" + print(f"> {key.ljust(width_name)} {value}") + print(f"Total {len(envs_read)} values") + print("===================================") + + + # Create the table format string + format_string = f"| {{:<{width_name}}} | {{:<{ + width_type}}} | {{:<{width_value}}} | {{:<{width_flags}}} |" + + # Calculate total width + total_width = width_name + width_type + width_value + \ + width_flags + 13 # 13 accounts for separators and spaces + + # Print the header + print("+" + "-" * (total_width - 2) + "+") + print(format_string.format("Name", "Type", "Value", "Flags")) + print("+" + "=" * (width_name + 2) + "+" + "=" * (width_type + 2) + + "+" + "=" * (width_value + 2) + "+" + "=" * (width_flags + 2) + "+") + + # Print each environment variable + for name, value, is_default in envs_read: + flags = "default" if is_default else "" + print(format_string.format(name, type(value).__name__, str(value), flags)) + print("+" + "-" * (width_name + 2) + "+" + "-" * (width_type + 2) + + "+" + "-" * (width_value + 2) + "+" + "-" * (width_flags + 2) + "+") + + +def val_exists(value, msg_error="The file or directory () does not exist, or I don't have permission to read it"): + if not os.path.exists(value): + raise Exception(msg_error.replace("()", f"'{value}'")) + + +def val_file_exists(value, msg_error="The file () does not exist, or I don't have permission to read it"): + if not os.path.isfile(value): + raise Exception(msg_error.replace("()", f"'{value}'")) + + +def val_dir_exists(value, msg_error="The directory () does not exist, or I don't have permission to read it", create=False): + if not os.path.isdir(value): + if create: + if os.path.exists(value): + raise Exception(f"Attempted to create directory '{value}', but it already exists and is not a directory") + os.makedirs(value) + return + raise Exception(msg_error.replace("()", f"'{value}'"))