diff --git a/aimodel/src/encoderonly_test_rainfall.py b/aimodel/src/encoderonly_test_rainfall.py index 2925222..636111b 100755 --- a/aimodel/src/encoderonly_test_rainfall.py +++ b/aimodel/src/encoderonly_test_rainfall.py @@ -17,17 +17,18 @@ from lib.ai.helpers.summarywriter import summarywriter # ███████ ██ ████ ████ # TODO: env vars & settings here -DIRPATH_INPUT = os.environ["DIRPATH_INPUT"] +DIRPATH_RAINFALLWATER = os.environ["DIRPATH_RAINFALLWATER"] DIRPATH_OUTPUT = os.environ["DIRPATH_OUTPUT"] PATH_HEIGHTMAP = os.environ["PATH_HEIGHTMAP"] if "PATH_HEIGHTMAP" in os.environ else None CHANNELS = os.environ["CHANNELS"] if "CHANNELS" in os.environ else 8 +EPOCHS = int(os.environ["EPOCHS"]) if "EPOCHS" in os.environ else 25 BATCH_SIZE = int(os.environ["BATCH_SIZE"]) if "BATCH_SIZE" in os.environ else 64 WINDOW_SIZE = int(os.environ["WINDOW_SIZE"]) if "WINDOW_SIZE" in os.environ else 33 STEPS_PER_EPOCH = int(os.environ["STEPS_PER_EPOCH"]) if "STEPS_PER_EPOCH" in os.environ else None logger.info("Encoder-only rainfall radar TEST") -logger.info(f"> DIRPATH_INPUT {DIRPATH_INPUT}") +logger.info(f"> DIRPATH_RAINFALLWATER {DIRPATH_RAINFALLWATER}") logger.info(f"> DIRPATH_OUTPUT {DIRPATH_OUTPUT}") logger.info(f"> PATH_HEIGHTMAP {PATH_HEIGHTMAP}") logger.info(f"> CHANNELS {CHANNELS}") @@ -47,7 +48,7 @@ if not os.path.exists(DIRPATH_OUTPUT): # ██████ ██ ██ ██ ██ ██ ███████ ███████ ██ dataset_train, dataset_validate = dataset_encoderonly( - dirpath_input=DIRPATH_INPUT, + dirpath_input=DIRPATH_RAINFALLWATER, filepath_heightmap=PATH_HEIGHTMAP, batch_size=BATCH_SIZE, windowsize=WINDOW_SIZE, @@ -112,7 +113,7 @@ summarywriter(model, os.path.join(DIRPATH_OUTPUT, "summary.txt")) history = model.fit(dataset_train, validation_data=dataset_validate, - epochs=25, + epochs=EPOCHS, callbacks=[ tf.keras.callbacks.CSVLogger(