mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 01:12:59 +00:00
deeplabv3+: prepare for ConvNeXt
This commit is contained in:
parent
779b546897
commit
e565c36149
1 changed files with 13 additions and 7 deletions
|
@ -135,21 +135,27 @@ if PATH_CHECKPOINT is None:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
|
def DeeplabV3Plus(image_size, num_classes, num_channels=3, backbone="resnet"):
|
||||||
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
||||||
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||||
resnet50 = tf.keras.applications.ResNet50(
|
|
||||||
weights="imagenet" if num_channels == 3 else None,
|
match backbone:
|
||||||
include_top=False, input_tensor=x
|
case "resnet":
|
||||||
)
|
backbone = tf.keras.applications.ResNet50(
|
||||||
x = resnet50.get_layer("conv4_block6_2_relu").output
|
weights="imagenet" if num_channels == 3 else None,
|
||||||
|
include_top=False, input_tensor=x
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise Exception(f"Error: Unknown backbone {backbone}")
|
||||||
|
|
||||||
|
x = backbone.get_layer("conv4_block6_2_relu").output
|
||||||
x = DilatedSpatialPyramidPooling(x)
|
x = DilatedSpatialPyramidPooling(x)
|
||||||
|
|
||||||
input_a = tf.keras.layers.UpSampling2D(
|
input_a = tf.keras.layers.UpSampling2D(
|
||||||
size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
|
size=(image_size // 4 // x.shape[1] * 2, image_size // 4 // x.shape[2] * 2), # <--- UPSAMPLE after pyramid
|
||||||
interpolation="bilinear",
|
interpolation="bilinear",
|
||||||
)(x)
|
)(x)
|
||||||
input_b = resnet50.get_layer("conv2_block3_2_relu").output
|
input_b = backbone.get_layer("conv2_block3_2_relu").output
|
||||||
input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
|
input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
|
||||||
|
|
||||||
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
|
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
|
||||||
|
|
Loading…
Reference in a new issue