mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
dlr eo: add JIT_COMPILE and MIXED_PRECISION
This commit is contained in:
parent
71088b8c0b
commit
2bf1872aca
1 changed files with 10 additions and 3 deletions
|
@ -28,6 +28,8 @@ WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ e
|
||||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||||
STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else None
|
STEPS_PER_EXECUTION = int(os.environ["STEPS_PER_EXECUTION"]) if "STEPS_PER_EXECUTION" in os.environ else None
|
||||||
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
|
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
|
||||||
|
JIT_COMPILE = True if "JIT_COMPILE" in os.environ else False
|
||||||
|
MIXED_PRECISION = True if "MIXED_PRECISION" in os.environ else False
|
||||||
|
|
||||||
logger.info("Encoder-only rainfall radar TEST")
|
logger.info("Encoder-only rainfall radar TEST")
|
||||||
logger.info(f"> DIRPATH_RAINFALLWATER {DIRPATH_RAINFALLWATER}")
|
logger.info(f"> DIRPATH_RAINFALLWATER {DIRPATH_RAINFALLWATER}")
|
||||||
|
@ -43,6 +45,9 @@ logger.info(f"> LEARNING_RATE {LEARNING_RATE}")
|
||||||
if not os.path.exists(DIRPATH_OUTPUT):
|
if not os.path.exists(DIRPATH_OUTPUT):
|
||||||
os.makedirs(os.path.join(DIRPATH_OUTPUT, "checkpoints"))
|
os.makedirs(os.path.join(DIRPATH_OUTPUT, "checkpoints"))
|
||||||
|
|
||||||
|
if MIXED_PRECISION:
|
||||||
|
tf.keras.mixed_precision.set_policy("mixed_float16")
|
||||||
|
|
||||||
|
|
||||||
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
# ██████ █████ ████████ █████ ███████ ███████ ████████
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
|
@ -65,7 +70,7 @@ dataset_train, dataset_validate = dataset_encoderonly(
|
||||||
# ██ ██ ██ ██ ██ ██ ██ ██ ██
|
# ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||||
# ██ ██ ██████ ██████ ███████ ███████
|
# ██ ██ ██████ ██████ ███████ ███████
|
||||||
|
|
||||||
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, steps_per_execution=1, **kwargs):
|
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, steps_per_execution=1, jit_compile=False, **kwargs):
|
||||||
if encoder == "convnext":
|
if encoder == "convnext":
|
||||||
model = make_convnext(
|
model = make_convnext(
|
||||||
input_shape=(windowsize, windowsize, channels),
|
input_shape=(windowsize, windowsize, channels),
|
||||||
|
@ -97,7 +102,8 @@ def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, ste
|
||||||
metrics = [
|
metrics = [
|
||||||
tf.keras.metrics.SparseCategoricalAccuracy()
|
tf.keras.metrics.SparseCategoricalAccuracy()
|
||||||
],
|
],
|
||||||
steps_per_execution=steps_per_execution
|
steps_per_execution=steps_per_execution,
|
||||||
|
jit_compile=jit_compile
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -106,7 +112,8 @@ def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, ste
|
||||||
model = make_encoderonly(
|
model = make_encoderonly(
|
||||||
windowsize=WINDOW_SIZE,
|
windowsize=WINDOW_SIZE,
|
||||||
channels=CHANNELS,
|
channels=CHANNELS,
|
||||||
steps_per_execution=STEPS_PER_EXECUTION
|
steps_per_execution=STEPS_PER_EXECUTION,
|
||||||
|
jit_compile=JIT_COMPILE
|
||||||
)
|
)
|
||||||
summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))
|
summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue