Mirror of https://github.com/sbrl/research-rainfallradar, synced 2024-11-24 18:23:01 +00:00
Make dirpath_rainfallwater consistent with other experiments

parent e72d3991b8
commit 818d77c733

1 changed file with 5 additions and 4 deletions
@@ -17,17 +17,18 @@ from lib.ai.helpers.summarywriter import summarywriter
 # ███████ ██ ████ ████

 # TODO: env vars & settings here
-DIRPATH_INPUT = os.environ["DIRPATH_INPUT"]
+DIRPATH_RAINFALLWATER = os.environ["DIRPATH_RAINFALLWATER"]
 DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"]
 PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] if "PATH_HEIGHTMAP" in os.environ else None
 CHANNELS = os.environ["CHANNELS"] if "CHANNELS" in os.environ else 8

+EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 25
 BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
 WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ else 33
 STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None

 logger.info("Encoder-only rainfall radar TEST")
-logger.info(f"> DIRPATH_INPUT {DIRPATH_INPUT}")
+logger.info(f"> DIRPATH_RAINFALLWATER {DIRPATH_RAINFALLWATER}")
 logger.info(f"> DIRPATH_OUTPUT {DIRPATH_OUTPUT}")
 logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
 logger.info(f"> CHANNELS {CHANNELS}")
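For reference, the settings block above — which reads every option with the `"X" in os.environ` idiom — is equivalent to the following sketch using `os.environ.get()`; this is not a change made in this commit, and the names and defaults are taken directly from the hunk:

import os

DIRPATH_RAINFALLWATER = os.environ["DIRPATH_RAINFALLWATER"]  # required; raises KeyError if unset
DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"]                # required
PATH_HEIGHTMAP = os.environ.get("PATH_HEIGHTMAP")            # optional; defaults to None
CHANNELS = os.environ.get("CHANNELS", 8)                     # optional; note the script leaves this value uncast
EPOCHS = int(os.environ.get("EPOCHS", 25))                   # added in this commit; defaults to 25
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 64))
WINDOW_SIZE = int(os.environ.get("WINDOW_SIZE", 33))
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None  # cast only when set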
@@ -47,7 +48,7 @@ if not os.path.exists(DIRPATH_OUTPUT):
 # ██████ ██ ██ ██ ██ ██ ███████ ███████ ██

 dataset_train, dataset_validate = dataset_encoderonly(
-	dirpath_input=DIRPATH_INPUT,
+	dirpath_input=DIRPATH_RAINFALLWATER,
 	filepath_heightmap=PATH_HEIGHTMAP,
 	batch_size=BATCH_SIZE,
 	windowsize=WINDOW_SIZE,
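Since dataset_encoderonly() now receives DIRPATH_RAINFALLWATER directly, a pre-flight check such as the minimal sketch below could run before the call above; it is purely illustrative and not part of the repository:

import os

# Fail fast with a clear message if the renamed input directory is missing or empty.
if not os.path.isdir(DIRPATH_RAINFALLWATER):
	raise FileNotFoundError(f"DIRPATH_RAINFALLWATER does not exist: {DIRPATH_RAINFALLWATER}")
if not os.listdir(DIRPATH_RAINFALLWATER):
	raise ValueError(f"DIRPATH_RAINFALLWATER is empty: {DIRPATH_RAINFALLWATER}")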
@@ -112,7 +113,7 @@ summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))

 history = model.fit(dataset_train,
 	validation_data=dataset_validate,
-	epochs=25,
+	epochs=EPOCHS,

 	callbacks=[
 		tf.keras.callbacks.CSVLogger(
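The hunk above swaps the hardcoded epochs=25 for the new EPOCHS setting while logging metrics through tf.keras.callbacks.CSVLogger. A minimal, self-contained sketch of that pattern with a toy model follows; the model and data are placeholders, not the repository's encoder:

import os
import numpy as np
import tensorflow as tf

# Epoch count comes from the environment, defaulting to 25 as in the script.
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 25

# Toy stand-in data; the real script builds its datasets with dataset_encoderonly().
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model = tf.keras.Sequential([
	tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
	tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

history = model.fit(
	x, y,
	epochs=EPOCHS,  # parameterised, matching the epochs=EPOCHS change above
	callbacks=[tf.keras.callbacks.CSVLogger("metrics.csv")],  # writes per-epoch metrics to CSV
)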