deeplabv3+: prepare for ConvNeXt

This commit is contained in:
Starbeamrainbowlabs 2023-03-14 21:51:41 +00:00
parent 779b546897
commit e565c36149
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -135,21 +135,27 @@ if PATH_CHECKPOINT is None:
return output return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3): def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet"):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
x = tf.keras.layers.UpSampling2D(size=2)(model_input) x = tf.keras.layers.UpSampling2D(size=2)(model_input)
resnet50 = tf.keras.applications.ResNet50(
weights="imagenet" if num_channels == 3 else None, match backbone:
include_top=False, input_tensor=x case "resnet":
) backbone = tf.keras.applications.ResNet50(
x = resnet50.get_layer("conv4_block6_2_relu").output weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=x
)
case _:
raise Exception(f"Error: Unknown backbone {backbone}")
x = backbone.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x) x = DilatedSpatialPyramidPooling(x)
input_a = tf.keras.layers.UpSampling2D( input_a = tf.keras.layers.UpSampling2D(
size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
interpolation="bilinear", interpolation="bilinear",
)(x) )(x)
input_b = resnet50.get_layer("conv2_block3_2_relu").output input_b = backbone.get_layer("conv2_block3_2_relu").output
input_b = convolution_block(input_b, num_filters=48, kernel_size=1) input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b]) x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])