mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-21 17:03:00 +00:00
Implement env from PhD-smflooding-scene
This commit is contained in:
parent
a75d4f5d79
commit
5d62e3cee8
2 changed files with 172 additions and 26 deletions
|
@ -20,6 +20,7 @@ import matplotlib.pyplot as plt
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
import lib.primitives.env
|
||||
from lib.dataset.dataset_mono import dataset_mono, dataset_mono_predict
|
||||
from lib.ai.components.LossCrossEntropyDice import LossCrossEntropyDice
|
||||
from lib.ai.components.MetricDice import metric_dice_coefficient as dice_coefficient
|
||||
|
@ -37,41 +38,41 @@ logger.info(f"Starting at {str(datetime.now().isoformat())}")
|
|||
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||
# ███████ ██ ████ ████ ██ ██ ██ ██████ ██ ████ ██ ██ ███████ ██ ████ ██
|
||||
|
||||
IMAGE_SIZE = int(os.environ["IMAGE_SIZE"]) if "IMAGE_SIZE" in os.environ else 128 # was 512; 128 is the highest power of 2 that fits the data
|
||||
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||
IMAGE_SIZE = env.read("IMAGE_SIZE", int, 128) # was 512; 128 is the highest power of 2 that fits the data
|
||||
BATCH_SIZE = env.read("BATCH_SIZE", int, 64)
|
||||
NUM_CLASSES = 2
|
||||
DIR_RAINFALLWATER = os.environ["DIR_RAINFALLWATER"]
|
||||
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"]
|
||||
PATH_COLOURMAP = os.environ["PATH_COLOURMAP"]
|
||||
PARALLEL_READS = float(os.environ["PARALLEL_READS"]) if "PARALLEL_READS" in os.environ else 1.5
|
||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||
REMOVE_ISOLATED_PIXELS = False if "NO_REMOVE_ISOLATED_PIXELS" in os.environ else True
|
||||
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 50
|
||||
LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice" # other possible valuesL cross-entropy
|
||||
DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False
|
||||
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
|
||||
WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1
|
||||
UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2
|
||||
DIR_RAINFALLWATER = env.read("DIR_RAINFALLWATER", str)
|
||||
PATH_HEIGHTMAP = env.read("PATH_HEIGHTMAP", str)
|
||||
PATH_COLOURMAP = env.read("PATH_COLOURMAP", str)
|
||||
PARALLEL_READS = env.read("PARALLEL_READS", float, 1.5)
|
||||
STEPS_PER_EPOCH = env.read("STEPS_PER_EPOCH", int, None)
|
||||
REMOVE_ISOLATED_PIXELS = env.read("NO_REMOVE_ISOLATED_PIXELS", bool, True)
|
||||
EPOCHS = env.read("EPOCHS", int, 50)
|
||||
LOSS = env.read("LOSS", str, "cross-entropy-dice") # other possible values: cross-entropy
|
||||
DICE_LOG_COSH = env.read("DICE_LOG_COSH", bool, False)
|
||||
LEARNING_RATE = env.read("LEARNING_RATE", float, 0.001)
|
||||
WATER_THRESHOLD = env.read("WATER_THRESHOLD", float, 0.1)
|
||||
UPSAMPLE = env.read("UPSAMPLE", int, 2)
|
||||
SPLIT_VALIDATE = env.read("SPLIT_VALIDATE", float, 0.2)
|
||||
SPLIT_TEST = env.read("SPLIT_TEST", float, 0)
|
||||
|
||||
|
||||
STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else 1
|
||||
JIT_COMPILE = True if "JIT_COMPILE" in os.environ else False
|
||||
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||
|
||||
PATH_CHECKPOINT = os.environ["PATH_CHECKPOINT"] if "PATH_CHECKPOINT" in os.environ else None
|
||||
PREDICT_COUNT = int(os.environ["PREDICT_COUNT"]) if "PREDICT_COUNT" in os.environ else 25
|
||||
PREDICT_AS_ONE = True if "PREDICT_AS_ONE" in os.environ else False
|
||||
|
||||
STEPS_PER_EXECUTION = env.read("STEPS_PER_EXECUTION", int, 1)
|
||||
JIT_COMPILE = env.read("JIT_COMPILE", bool, False)
|
||||
DIR_OUTPUT = env.read("DIR_OUTPUT", str, f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST")
|
||||
PATH_CHECKPOINT = env.read("PATH_CHECKPOINT", str, None)
|
||||
PREDICT_COUNT = env.read("PREDICT_COUNT", int, 25)
|
||||
PREDICT_AS_ONE = env.read("PREDICT_AS_ONE", bool, False)
|
||||
# ~~~
|
||||
|
||||
if not os.path.exists(DIR_OUTPUT):
|
||||
os.makedirs(os.path.join(DIR_OUTPUT, "checkpoints"))
|
||||
env.val_dir_exists(os.path.join(DIR_OUTPUT, "checkpoints"), create=True)
|
||||
|
||||
# ~~~
|
||||
|
||||
logger.info("DeepLabV3+ rainfall radar TEST")
|
||||
for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
|
||||
logger.info(f"> {env_name} {str(globals()[env_name])}")
|
||||
env.print_all(False)
|
||||
# for env_name in [ "BATCH_SIZE","NUM_CLASSES", "DIR_RAINFALLWATER", "PATH_HEIGHTMAP", "PATH_COLOURMAP", "STEPS_PER_EPOCH", "PARALLEL_READS", "REMOVE_ISOLATED_PIXELS", "EPOCHS", "LOSS", "LEARNING_RATE", "DIR_OUTPUT", "PATH_CHECKPOINT", "PREDICT_COUNT", "DICE_LOG_COSH", "WATER_THRESHOLD", "UPSAMPLE", "STEPS_PER_EXECUTION", "JIT_COMPILE", "PREDICT_AS_ONE" ]:
|
||||
# logger.info(f"> {env_name} {str(globals()[env_name])}")
|
||||
|
||||
|
||||
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
||||
|
|
145
aimodel/src/lib/primitives/env.py
Normal file
145
aimodel/src/lib/primitives/env.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import os
|
||||
|
||||
# Ref https://stackoverflow.com/a/61733714/1460422
|
||||
|
||||
###
|
||||
## Environment parsing and validation helpers
|
||||
## @sbrl, Licence: GPLv3
|
||||
###
|
||||
|
||||
## Changelog:
|
||||
# 2024-09-29: Create this changelog, prepare for reuse
|
||||
|
||||
##############################################################################
|
||||
|
||||
# Simple polyfill for Symbol from JS: https://devdocs.io/javascript/global_objects/symbol
|
||||
|
||||
class Symbol:
|
||||
def __init__(self, name=''):
|
||||
self.name = f"Symbol({name})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
##############################################################################
|
||||
|
||||
|
||||
SYM_RAISE_EXCEPTION = Symbol("__env_read_raise_exception")
|
||||
|
||||
envs_read = []
|
||||
|
||||
|
||||
def read(name, type_class, default=SYM_RAISE_EXCEPTION):
|
||||
"""
|
||||
Reads, parses, and returns an environment variable with the specified name and type, with an optional default value.
|
||||
|
||||
If the environment variable does not exist and no default value is provided, an `Exception` is raised. Otherwise, the environment variable value is converted to the specified type and returned.
|
||||
|
||||
If `type_class == bool` and no `default` value is provided, then the default value is set to `False` and an `Exception` is **not** raised.
|
||||
|
||||
Args:
|
||||
name (str): The name of the environment variable to read.
|
||||
type_class (type): The type to convert the environment variable value to.
|
||||
default (Any, optional): The default value to use if the environment variable does not exist. Defaults to `SYM_RAISE_EXCEPTION`, which will raise an exception if the variable does not exist.
|
||||
|
||||
Returns:
|
||||
Any: The environment variable value converted to the specified type.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment variable does not exist and no default value is provided.
|
||||
"""
|
||||
|
||||
if name not in os.environ:
|
||||
if type_class == bool and default == SYM_RAISE_EXCEPTION:
|
||||
default = False
|
||||
if default == SYM_RAISE_EXCEPTION:
|
||||
raise Exception(f"Error: Environment variable {name} does not exist")
|
||||
envs_read.append([name, default, True])
|
||||
return default
|
||||
|
||||
result = os.environ[name]
|
||||
if type_class == bool:
|
||||
result = False if default == True else True
|
||||
else:
|
||||
result = type_class(result)
|
||||
|
||||
envs_read.append([name, result, False])
|
||||
return result
|
||||
|
||||
|
||||
def print_all(table=True):
|
||||
"""
|
||||
Prints a formatted table of all environment variables that have been read so far.
|
||||
|
||||
The table includes the name, type, value, and flags of each environment variable. The column widths are automatically adjusted to fit the longest values.
|
||||
|
||||
If no environment variables have been read yet, a message is printed indicating that.
|
||||
|
||||
Args:
|
||||
table (bool): Whether to print in a table or not. Defaults to True. If False, then values are directly printed in a list instead of in a pretty table.
|
||||
"""
|
||||
|
||||
if not envs_read:
|
||||
print("No environment variables have been read yet.")
|
||||
return
|
||||
|
||||
|
||||
# Calculate column widths
|
||||
width_name = max(len("Name"), max(len(env[0]) for env in envs_read))
|
||||
width_type = max(len("Type"), max(
|
||||
len(type(env[1]).__name__) for env in envs_read))
|
||||
width_value = max(len("Value"), max(len(str(env[1])) for env in envs_read))
|
||||
width_flags = max(len("Flags"), len("default"))
|
||||
|
||||
|
||||
|
||||
if not table:
|
||||
print("===================================")
|
||||
for env in envs_read:
|
||||
key, value, is_default = env
|
||||
prefix = "* " if is_default else ""
|
||||
print(f"> {key.ljust(width_name)} {value}")
|
||||
print(f"Total {len(envs_read)} values")
|
||||
print("===================================")
|
||||
|
||||
|
||||
# Create the table format string
|
||||
format_string = f"| {{:<{width_name}}} | {{:<{
|
||||
width_type}}} | {{:<{width_value}}} | {{:<{width_flags}}} |"
|
||||
|
||||
# Calculate total width
|
||||
total_width = width_name + width_type + width_value + \
|
||||
width_flags + 13 # 13 accounts for separators and spaces
|
||||
|
||||
# Print the header
|
||||
print("+" + "-" * (total_width - 2) + "+")
|
||||
print(format_string.format("Name", "Type", "Value", "Flags"))
|
||||
print("+" + "=" * (width_name + 2) + "+" + "=" * (width_type + 2) +
|
||||
"+" + "=" * (width_value + 2) + "+" + "=" * (width_flags + 2) + "+")
|
||||
|
||||
# Print each environment variable
|
||||
for name, value, is_default in envs_read:
|
||||
flags = "default" if is_default else ""
|
||||
print(format_string.format(name, type(value).__name__, str(value), flags))
|
||||
print("+" + "-" * (width_name + 2) + "+" + "-" * (width_type + 2) +
|
||||
"+" + "-" * (width_value + 2) + "+" + "-" * (width_flags + 2) + "+")
|
||||
|
||||
|
||||
def val_exists(value, msg_error="The file or directory () does not exist, or I don't have permission to read it"):
|
||||
if not os.path.exists(value):
|
||||
raise Exception(msg_error.replace("()", f"'{value}'"))
|
||||
|
||||
|
||||
def val_file_exists(value, msg_error="The file () does not exist, or I don't have permission to read it"):
|
||||
if not os.path.isfile(value):
|
||||
raise Exception(msg_error.replace("()", f"'{value}'"))
|
||||
|
||||
|
||||
def val_dir_exists(value, msg_error="The directory () does not exist, or I don't have permission to read it", create=False):
|
||||
if not os.path.isdir(value):
|
||||
if create:
|
||||
if os.path.exists(value):
|
||||
raise Exception(f"Attempted to create directory '{value}', but it already exists and is not a directory")
|
||||
os.makedirs(value)
|
||||
return
|
||||
raise Exception(msg_error.replace("()", f"'{value}'"))
|
Loading…
Reference in a new issue