mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-04 17:13:02 +00:00
weights="imagenet" only works with 3 image channels
This commit is contained in:
parent
4563fe6b27
commit
56a501f8a9
1 changed files with 2 additions and 1 deletions
|
@ -100,7 +100,8 @@ def DilatedSpatialPyramidPooling(dspp_input):
|
||||||
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
|
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
|
||||||
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
||||||
resnet50 = tf.keras.applications.ResNet50(
|
resnet50 = tf.keras.applications.ResNet50(
|
||||||
weights="imagenet", include_top=False, input_tensor=model_input
|
weights="imagenet" if num_channels == 3 else None,
|
||||||
|
include_top=False, input_tensor=model_input
|
||||||
)
|
)
|
||||||
x = resnet50.get_layer("conv4_block6_2_relu").output
|
x = resnet50.get_layer("conv4_block6_2_relu").output
|
||||||
x = DilatedSpatialPyramidPooling(x)
|
x = DilatedSpatialPyramidPooling(x)
|
||||||
|
|
Loading…
Reference in a new issue