dlr: add UPSAMPLE env var

...AND actually add the functionality this time!
This commit is contained in:
Starbeamrainbowlabs 2023-05-04 17:40:16 +01:00
parent 31687da931
commit e2e6a56b40
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 16 additions and 5 deletions

View file

@ -6,7 +6,7 @@
#SBATCH -o %j.%N.%a.deeplab-rainfall.out.log
#SBATCH -e %j.%N.%a.deeplab-rainfall.err.log
#SBATCH -p gpu
#SBATCH --no-requeue
#SBATCH --no-requeue
#SBATCH --time=5-00:00:00
#SBATCH --mem=30000
# ---> in MiB
@ -41,6 +41,7 @@ show_help() {
echo -e " WATER_THRESHOLD The threshold to cut water off at when training, in metres. Default: 0.1" >&2;
echo -e " PATH_CHECKPOINT The path to a checkcpoint to load. If specified, a model will be loaded instead of being trained." >&2;
echo -e " LEARNING_RATE The learning rate to use. Default: 0.001." >&2;
echo -e " UPSAMPLE How much to upsample by at the beginning of the model. A value of disables upscaling. Default: 2." >&2;
echo -e " PREDICT_COUNT The number of items from the (SCRAMBLED) dataset to make a prediction for." >&2;
echo -e " POSTFIX Postfix to append to the output dir (auto calculated)." >&2;
echo -e " ARGS Optional. Any additional arguments to pass to the python program." >&2;

View file

@ -50,6 +50,7 @@ LOSS = os.environ["LOSS"] if "LOSS" in os.environ else "cross-entropy-dice"
DICE_LOG_COSH = True if "DICE_LOG_COSH" in os.environ else False
LEARNING_RATE = float(os.environ["LEARNING_RATE"]) if "LEARNING_RATE" in os.environ else 0.001
WATER_THRESHOLD = float(os.environ["WATER_THRESHOLD"]) if "WATER_THRESHOLD" in os.environ else 0.1
UPSAMPLE = int(os.environ["UPSAMPLE"]) if "UPSAMPLE" in os.environ else 2
DIR_OUTPUT=os.environ["DIR_OUTPUT"] if "DIR_OUTPUT" in os.environ else f"output/{datetime.utcnow().date().isoformat()}_deeplabv3plus_rainfall_TEST"
@ -135,9 +136,14 @@ if PATH_CHECKPOINT is None:
return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet"):
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet", upsample=2):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
if upsample > 1:
logger.info(f"[DeepLabV3+] Upsample enabled @ {upsample}x")
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
else:
logger.info(f"[DeepLabV3+] Upsample disabled")
x = model_input
match backbone:
case "resnet":
@ -168,8 +174,12 @@ if PATH_CHECKPOINT is None:
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
return tf.keras.Model(inputs=model_input, outputs=model_output)
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
model = DeeplabV3Plus(
image_size=IMAGE_SIZE,
num_classes=NUM_CLASSES,
upsample=UPSAMPLE,
num_channels=8
)
summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
else:
model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={