weights="imagenet" only works with 3 image channels

This commit is contained in:
Starbeamrainbowlabs 2023-01-05 19:09:31 +00:00
parent 4563fe6b27
commit 56a501f8a9
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -100,7 +100,8 @@ def DilatedSpatialPyramidPooling(dspp_input):
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
resnet50 = tf.keras.applications.ResNet50(
weights="imagenet", include_top=False, input_tensor=model_input
weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=model_input
)
x = resnet50.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x)