mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-24 10:13:00 +00:00
Make dirpath_rainfallwater consistent with other experiments
This commit is contained in:
parent
e72d3991b8
commit
818d77c733
1 changed files with 5 additions and 4 deletions
|
@ -17,17 +17,18 @@ from lib.ai.helpers.summarywriter import summarywriter
|
|||
# ███████ ██ ████ ████
|
||||
|
||||
# TODO: env vars & settings here
|
||||
DIRPATH_INPUT = os.environ["DIRPATH_INPUT"]
|
||||
DIRPATH_RAINFALLWATER = os.environ["DIRPATH_RAINFALLWATER"]
|
||||
DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"]
|
||||
PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] if "PATH_HEIGHTMAP" in os.environ else None
|
||||
CHANNELS = os.environ["CHANNELS"] if "CHANNELS" in os.environ else 8
|
||||
|
||||
EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 25
|
||||
BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64
|
||||
WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ else 33
|
||||
STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None
|
||||
|
||||
logger.info("Encoder-only rainfall radar TEST")
|
||||
logger.info(f"> DIRPATH_INPUT {DIRPATH_INPUT}")
|
||||
logger.info(f"> DIRPATH_RAINFALLWATER {DIRPATH_RAINFALLWATER}")
|
||||
logger.info(f"> DIRPATH_OUTPUT {DIRPATH_OUTPUT}")
|
||||
logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}")
|
||||
logger.info(f"> CHANNELS {CHANNELS}")
|
||||
|
@ -47,7 +48,7 @@ if not os.path.exists(DIRPATH_OUTPUT):
|
|||
# ██████ ██ ██ ██ ██ ██ ███████ ███████ ██
|
||||
|
||||
dataset_train, dataset_validate = dataset_encoderonly(
|
||||
dirpath_input=DIRPATH_INPUT,
|
||||
dirpath_input=DIRPATH_RAINFALLWATER,
|
||||
filepath_heightmap=PATH_HEIGHTMAP,
|
||||
batch_size=BATCH_SIZE,
|
||||
windowsize=WINDOW_SIZE,
|
||||
|
@ -112,7 +113,7 @@ summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt"))
|
|||
|
||||
history = model.fit(dataset_train,
|
||||
validation_data=dataset_validate,
|
||||
epochs=25,
|
||||
epochs=EPOCHS,
|
||||
|
||||
callbacks=[
|
||||
tf.keras.callbacks.CSVLogger(
|
||||
|
|
Loading…
Reference in a new issue