Implement env from PhD-smflooding-scene

This commit is contained in:
Starbeamrainbowlabs 2024-08-29 16:43:29 +01:00
parent a75d4f5d79
commit 5d62e3cee8
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 172 additions and 26 deletions

View file

@ -20,6 +20,7 @@ import matplotlib.pyplot as plt
import tensorflow as tf import tensorflow as tf
import lib.primitives.env
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
@ -37,41 +38,41 @@ logger.info(f"Starting at {str(datetime.now().isoformat())}")
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ # ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██ # ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64 BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
NUM_CLASSES = 2 NUM_CLASSES = 2
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"] DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
PATH_COLOURMAP = os.environ["PATH_COLOURMAP"] PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)
PARALLEL_READS = float(os.environ["PARALLEL_READS"]) if "PARALLEL_READS" in os.environ else 1.5 PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
REMOVE_ISOLATED_PIXELS = False if "NO_REMOVE_ISOLATED_PIXELS" in os.environ else True REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 50 EPOCHS = env.read("EPOCHS", int, 50)
LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice" # other possible valuesL cross-entropy LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy
DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001 LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1 WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2 UPSAMPLE = env.read("UPSAMPLE", int, 2)
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)
STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else 1 STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
JIT_COMPILE = True if "JIT_COMPILE" in os.environ else False JIT_COMPILE = env.read("JIT_COMPILE", bool, False)
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST" DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
PREDICT_COUNT = int(os.environ["PREDICT_COUNT"]) if "PREDICT_COUNT" in os.environ else 25 PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)
PREDICT_AS_ONE = True if "PREDICT_AS_ONE" in os.environ else False
# ~~~ # ~~~
if not os.path.exists(DIR_OUTPUT): env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
# ~~~ # ~~~
logger.info("DeepLabV3+ rainfall radar TEST") logger.info("DeepLabV3+ rainfall radar TEST")
for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]: env.print_all(False)
logger.info(f"> {env_name} {str(globals()[env_name])}") # for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
# logger.info(f"> {env_name} {str(globals()[env_name])}")
# ██████ █████ ████████ █████ ███████ ███████ ████████ # ██████ █████ ████████ █████ ███████ ███████ ████████

View file

@ -0,0 +1,145 @@
import os
# Ref https://stackoverflow.com/a/61733714/1460422
###
## Environment parsing and validation helpers
## @sbrl, Licence: GPLv3
###
## Changelog:
# 2024-09-29: Create this changelog, prepare for reuse
##############################################################################
# Simple polyfill for Symbol from JS: https://devdocs.io/javascript/global_objects/symbol
class Symbol:
def __init__(self, name=''):
self.name = f"Symbol({name})"
def __repr__(self):
return self.name
##############################################################################
SYM_RAISE_EXCEPTION = Symbol("__env_read_raise_exception")
envs_read = []
def read(name, type_class, default=SYM_RAISE_EXCEPTION):
"""
Reads, parses, and returns an environment variable with the specified name and type, with an optional default value.
If the environment variable does not exist and no default value is provided, an `Exception` is raised. Otherwise, the environment variable value is converted to the specified type and returned.
If `type_class == bool` and no `default` value is provided, then the default value is set to `False` and an `Exception` is **not** raised.
Args:
name (str): The name of the environment variable to read.
type_class (type): The type to convert the environment variable value to.
default (Any, optional): The default value to use if the environment variable does not exist. Defaults to `SYM_RAISE_EXCEPTION`, which will raise an exception if the variable does not exist.
Returns:
Any: The environment variable value converted to the specified type.
Raises:
Exception: If the environment variable does not exist and no default value is provided.
"""
if name not in os.environ:
if type_class == bool and default == SYM_RAISE_EXCEPTION:
default = False
if default == SYM_RAISE_EXCEPTION:
raise Exception(f"Error: Environment variable {name} does not exist")
envs_read.append([name, default, True])
return default
result = os.environ[name]
if type_class == bool:
result = False if default == True else True
else:
result = type_class(result)
envs_read.append([name, result, False])
return result
def print_all(table=True):
"""
Prints a formatted table of all environment variables that have been read so far.
The table includes the name, type, value, and flags of each environment variable. The column widths are automatically adjusted to fit the longest values.
If no environment variables have been read yet, a message is printed indicating that.
Args:
table (bool): Whether to print in a table or not. Defaults to True. If False, then values are directly printed in a list instead of in a pretty table.
"""
if not envs_read:
print("No environment variables have been read yet.")
return
# Calculate column widths
width_name = max(len("Name"), max(len(env[0]) for env in envs_read))
width_type = max(len("Type"), max(
len(type(env[1]).__name__) for env in envs_read))
width_value = max(len("Value"), max(len(str(env[1])) for env in envs_read))
width_flags = max(len("Flags"), len("default"))
if not table:
print("===================================")
for env in envs_read:
key, value, is_default = env
prefix = "* " if is_default else ""
print(f"> {key.ljust(width_name)} {value}")
print(f"Total {len(envs_read)} values")
print("===================================")
# Create the table format string
format_string = f"| {{:<{width_name}}} | {{:<{
width_type}}} | {{:<{width_value}}} | {{:<{width_flags}}} |"
# Calculate total width
total_width = width_name + width_type + width_value + \
width_flags + 13 # 13 accounts for separators and spaces
# Print the header
print("+" + "-" * (total_width - 2) + "+")
print(format_string.format("Name", "Type", "Value", "Flags"))
print("+" + "=" * (width_name + 2) + "+" + "=" * (width_type + 2) +
"+" + "=" * (width_value + 2) + "+" + "=" * (width_flags + 2) + "+")
# Print each environment variable
for name, value, is_default in envs_read:
flags = "default" if is_default else ""
print(format_string.format(name, type(value).__name__, str(value), flags))
print("+" + "-" * (width_name + 2) + "+" + "-" * (width_type + 2) +
"+" + "-" * (width_value + 2) + "+" + "-" * (width_flags + 2) + "+")
def val_exists(value, msg_error="The file or directory () does not exist, or I don't have permission to read it"):
if not os.path.exists(value):
raise Exception(msg_error.replace("()", f"'{value}'"))
def val_file_exists(value, msg_error="The file () does not exist, or I don't have permission to read it"):
if not os.path.isfile(value):
raise Exception(msg_error.replace("()", f"'{value}'"))
def val_dir_exists(value, msg_error="The directory () does not exist, or I don't have permission to read it", create=False):
if not os.path.isdir(value):
if create:
if os.path.exists(value):
raise Exception(f"Attempted to create directory '{value}', but it already exists and is not a directory")
os.makedirs(value)
return
raise Exception(msg_error.replace("()", f"'{value}'"))