mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
Implement env from PhD-smflooding-scene
This commit is contained in:
parent
a75d4f5d79
commit
5d62e3cee8
2 changed files with 172 additions and 26 deletions
|
@ -20,6 +20,7 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
import lib.primitives.env
|
||||||
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
||||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||||
|
@ -37,41 +38,41 @@ logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
|
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
|
||||||
|
|
||||||
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
|
IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data
|
||||||
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
|
||||||
NUM_CLASSES = 2
|
NUM_CLASSES = 2
|
||||||
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
|
DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
|
||||||
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
|
PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
|
||||||
PATH_COLOURMAP = os.environ["PATH_COLOURMAP"]
|
PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)
|
||||||
PARALLEL_READS = float(os.environ["PARALLEL_READS"]) if "PARALLEL_READS" in os.environ else 1.5
|
PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
|
||||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
|
||||||
REMOVE_ISOLATED_PIXELS = False if "NO_REMOVE_ISOLATED_PIXELS" in os.environ else True
|
REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)
|
||||||
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 50
|
EPOCHS = env.read("EPOCHS", int, 50)
|
||||||
LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice" # other possible valuesL cross-entropy
|
LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy
|
||||||
DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False
|
DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
|
||||||
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
|
LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
|
||||||
WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1
|
WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
|
||||||
UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2
|
UPSAMPLE = env.read("UPSAMPLE", int, 2)
|
||||||
|
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
|
||||||
|
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)
|
||||||
|
|
||||||
|
|
||||||
STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else 1
|
STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
|
||||||
JIT_COMPILE = True if "JIT_COMPILE" in os.environ else False
|
JIT_COMPILE = env.read("JIT_COMPILE", bool, False)
|
||||||
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
|
||||||
|
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)
|
||||||
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
|
PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
|
||||||
PREDICT_COUNT = int(os.environ["PREDICT_COUNT"]) if "PREDICT_COUNT" in os.environ else 25
|
PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)
|
||||||
PREDICT_AS_ONE = True if "PREDICT_AS_ONE" in os.environ else False
|
|
||||||
|
|
||||||
# ~~~
|
# ~~~
|
||||||
|
|
||||||
if not os.path.exists(DIR_OUTPUT):
|
env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)
|
||||||
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
|
|
||||||
|
|
||||||
# ~~~
|
# ~~~
|
||||||
|
|
||||||
logger.info("DeepLabV3+ rainfall radar TEST")
|
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||||
for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
|
env.print_all(False)
|
||||||
logger.info(f"> {env_name} {str(globals()[env_name])}")
|
# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
|
||||||
|
# logger.info(f"> {env_name} {str(globals()[env_name])}")
|
||||||
|
|
||||||
|
|
||||||
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
||||||
|
|
145
aimodel/src/lib/primitives/env.py
Normal file
145
aimodel/src/lib/primitives/env.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Ref https://stackoverflow.com/a/61733714/1460422
|
||||||
|
|
||||||
|
###
|
||||||
|
## Environment parsing and validation helpers
|
||||||
|
## @sbrl, Licence: GPLv3
|
||||||
|
###
|
||||||
|
|
||||||
|
## Changelog:
|
||||||
|
# 2024-09-29: Create this changelog, prepare for reuse
|
||||||
|
|
||||||
|
##############################################################################
|
||||||
|
|
||||||
|
# Simple polyfill for Symbol from JS: https://devdocs.io/javascript/global_objects/symbol
|
||||||
|
|
||||||
|
class Symbol:
|
||||||
|
def __init__(self, name=''):
|
||||||
|
self.name = f"Symbol({name})"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
##############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
SYM_RAISE_EXCEPTION = Symbol("__env_read_raise_exception")
|
||||||
|
|
||||||
|
envs_read = []
|
||||||
|
|
||||||
|
|
||||||
|
def read(name, type_class, default=SYM_RAISE_EXCEPTION):
|
||||||
|
"""
|
||||||
|
Reads, parses, and returns an environment variable with the specified name and type, with an optional default value.
|
||||||
|
|
||||||
|
If the environment variable does not exist and no default value is provided, an `Exception` is raised. Otherwise, the environment variable value is converted to the specified type and returned.
|
||||||
|
|
||||||
|
If `type_class == bool` and no `default` value is provided, then the default value is set to `False` and an `Exception` is **not** raised.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the environment variable to read.
|
||||||
|
type_class (type): The type to convert the environment variable value to.
|
||||||
|
default (Any, optional): The default value to use if the environment variable does not exist. Defaults to `SYM_RAISE_EXCEPTION`, which will raise an exception if the variable does not exist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The environment variable value converted to the specified type.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the environment variable does not exist and no default value is provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if name not in os.environ:
|
||||||
|
if type_class == bool and default == SYM_RAISE_EXCEPTION:
|
||||||
|
default = False
|
||||||
|
if default == SYM_RAISE_EXCEPTION:
|
||||||
|
raise Exception(f"Error: Environment variable {name} does not exist")
|
||||||
|
envs_read.append([name, default, True])
|
||||||
|
return default
|
||||||
|
|
||||||
|
result = os.environ[name]
|
||||||
|
if type_class == bool:
|
||||||
|
result = False if default == True else True
|
||||||
|
else:
|
||||||
|
result = type_class(result)
|
||||||
|
|
||||||
|
envs_read.append([name, result, False])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def print_all(table=True):
|
||||||
|
"""
|
||||||
|
Prints a formatted table of all environment variables that have been read so far.
|
||||||
|
|
||||||
|
The table includes the name, type, value, and flags of each environment variable. The column widths are automatically adjusted to fit the longest values.
|
||||||
|
|
||||||
|
If no environment variables have been read yet, a message is printed indicating that.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table (bool): Whether to print in a table or not. Defaults to True. If False, then values are directly printed in a list instead of in a pretty table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not envs_read:
|
||||||
|
print("No environment variables have been read yet.")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate column widths
|
||||||
|
width_name = max(len("Name"), max(len(env[0]) for env in envs_read))
|
||||||
|
width_type = max(len("Type"), max(
|
||||||
|
len(type(env[1]).__name__) for env in envs_read))
|
||||||
|
width_value = max(len("Value"), max(len(str(env[1])) for env in envs_read))
|
||||||
|
width_flags = max(len("Flags"), len("default"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if not table:
|
||||||
|
print("===================================")
|
||||||
|
for env in envs_read:
|
||||||
|
key, value, is_default = env
|
||||||
|
prefix = "* " if is_default else ""
|
||||||
|
print(f"> {key.ljust(width_name)} {value}")
|
||||||
|
print(f"Total {len(envs_read)} values")
|
||||||
|
print("===================================")
|
||||||
|
|
||||||
|
|
||||||
|
# Create the table format string
|
||||||
|
format_string = f"| {{:<{width_name}}} | {{:<{
|
||||||
|
width_type}}} | {{:<{width_value}}} | {{:<{width_flags}}} |"
|
||||||
|
|
||||||
|
# Calculate total width
|
||||||
|
total_width = width_name + width_type + width_value + \
|
||||||
|
width_flags + 13 # 13 accounts for separators and spaces
|
||||||
|
|
||||||
|
# Print the header
|
||||||
|
print("+" + "-" * (total_width - 2) + "+")
|
||||||
|
print(format_string.format("Name", "Type", "Value", "Flags"))
|
||||||
|
print("+" + "=" * (width_name + 2) + "+" + "=" * (width_type + 2) +
|
||||||
|
"+" + "=" * (width_value + 2) + "+" + "=" * (width_flags + 2) + "+")
|
||||||
|
|
||||||
|
# Print each environment variable
|
||||||
|
for name, value, is_default in envs_read:
|
||||||
|
flags = "default" if is_default else ""
|
||||||
|
print(format_string.format(name, type(value).__name__, str(value), flags))
|
||||||
|
print("+" + "-" * (width_name + 2) + "+" + "-" * (width_type + 2) +
|
||||||
|
"+" + "-" * (width_value + 2) + "+" + "-" * (width_flags + 2) + "+")
|
||||||
|
|
||||||
|
|
||||||
|
def val_exists(value, msg_error="The file or directory () does not exist, or I don't have permission to read it"):
|
||||||
|
if not os.path.exists(value):
|
||||||
|
raise Exception(msg_error.replace("()", f"'{value}'"))
|
||||||
|
|
||||||
|
|
||||||
|
def val_file_exists(value, msg_error="The file () does not exist, or I don't have permission to read it"):
|
||||||
|
if not os.path.isfile(value):
|
||||||
|
raise Exception(msg_error.replace("()", f"'{value}'"))
|
||||||
|
|
||||||
|
|
||||||
|
def val_dir_exists(value, msg_error="The directory () does not exist, or I don't have permission to read it", create=False):
|
||||||
|
if not os.path.isdir(value):
|
||||||
|
if create:
|
||||||
|
if os.path.exists(value):
|
||||||
|
raise Exception(f"Attempted to create directory '{value}', but it already exists and is not a directory")
|
||||||
|
os.makedirs(value)
|
||||||
|
return
|
||||||
|
raise Exception(msg_error.replace("()", f"'{value}'"))
|
Loading…
Reference in a new issue