From e565c3614941b92efe56df6636721ebf4b318c96 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Tue, 14 Mar 2023 21:51:41 +0000 Subject: [PATCH] deeplabv3+: prepare for ConvNeXt --- aimodel/src/deeplabv3_plus_test_rainfall.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index 9da620f..10912a9 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -135,21 +135,27 @@ if PATH_CHECKPOINT is None: return output - def DeeplabV3Plus(image_size, num_classes, num_channels=3): + def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet"): model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) x = tf.keras.layers.UpSampling2D(size=2)(model_input) - resnet50 = tf.keras.applications.ResNet50( - weights="imagenet" if num_channels == 3 else None, - include_top=False, input_tensor=x - ) - x = resnet50.get_layer("conv4_block6_2_relu").output + + match backbone: + case "resnet": + backbone = tf.keras.applications.ResNet50( + weights="imagenet" if num_channels == 3 else None, + include_top=False, input_tensor=x + ) + case _: + raise Exception(f"Error: Unknown backbone {backbone}") + + x = backbone.get_layer("conv4_block6_2_relu").output x = DilatedSpatialPyramidPooling(x) input_a = tf.keras.layers.UpSampling2D( size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid interpolation="bilinear", )(x) - input_b = resnet50.get_layer("conv2_block3_2_relu").output + input_b = backbone.get_layer("conv2_block3_2_relu").output input_b = convolution_block(input_b, num_filters=48, kernel_size=1) x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])