mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 14:15:01 +00:00
dlr: add UPSAMPLE env var
...AND actually add the functionality this time!
This commit is contained in:
parent
31687da931
commit
e2e6a56b40
2 changed files with 16 additions and 5 deletions
|
@ -6,7 +6,7 @@
|
|||
#SBATCH -o %j.%N.%a.deeplab-rainfall.out.log
|
||||
#SBATCH -e %j.%N.%a.deeplab-rainfall.err.log
|
||||
#SBATCH -p gpu
|
||||
#SBATCH --no-requeue
|
||||
#SBATCH --no-requeue
|
||||
#SBATCH --time=5-00:00:00
|
||||
#SBATCH --mem=30000
|
||||
# ---> in MiB
|
||||
|
@ -41,6 +41,7 @@ show_help() {
|
|||
echo -e " WATER_THRESHOLD The threshold to cut water off at when training, in metres. Default: 0.1" >&2;
|
||||
echo -e " PATH_CHECKPOINT The path to a checkcpoint to load. If specified, a model will be loaded instead of being trained." >&2;
|
||||
echo -e " LEARNING_RATE The learning rate to use. Default: 0.001." >&2;
|
||||
echo -e " UPSAMPLE How much to upsample by at the beginning of the model. A value of disables upscaling. Default: 2." >&2;
|
||||
echo -e " PREDICT_COUNT The number of items from the (SCRAMBLED) dataset to make a prediction for." >&2;
|
||||
echo -e " POSTFIX Postfix to append to the output dir (auto calculated)." >&2;
|
||||
echo -e " ARGS Optional. Any additional arguments to pass to the python program." >&2;
|
||||
|
|
|
@ -50,6 +50,7 @@ LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice"
|
|||
DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False
|
||||
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
|
||||
WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1
|
||||
UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2
|
||||
|
||||
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
|
||||
|
||||
|
@ -135,9 +136,14 @@ if PATH_CHECKPOINT is None:
|
|||
return output
|
||||
|
||||
|
||||
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet"):
|
||||
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet", upsample=2):
|
||||
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
||||
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||
if upsample > 1:
|
||||
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
|
||||
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||
else:
|
||||
logger.info(f"[DeepLabV3+] Upsample disabled")
|
||||
x = model_input
|
||||
|
||||
match backbone:
|
||||
case "resnet":
|
||||
|
@ -168,8 +174,12 @@ if PATH_CHECKPOINT is None:
|
|||
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
|
||||
return tf.keras.Model(inputs=model_input, outputs=model_output)
|
||||
|
||||
|
||||
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
|
||||
model = DeeplabV3Plus(
|
||||
image_size=IMAGE_SIZE,
|
||||
num_classes=NUM_CLASSES,
|
||||
upsample=UPSAMPLE,
|
||||
num_channels=8
|
||||
)
|
||||
summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
|
||||
else:
|
||||
model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
|
||||
|
|
Loading…
Reference in a new issue