From 56a501f8a972c9c53ce6c96916d0bd1f53698455 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 5 Jan 2023 19:09:31 +0000 Subject: [PATCH] weights="imagenet" only works with 3 image channels --- aimodel/src/deeplabv3_plus_test_rainfall.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index b410cc9..e5c6dd8 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -100,7 +100,8 @@ def DilatedSpatialPyramidPooling(dspp_input): def DeeplabV3Plus(image_size, num_classes, num_channels=3): model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) resnet50 = tf.keras.applications.ResNet50( - weights="imagenet", include_top=False, input_tensor=model_input + weights="imagenet" if num_channels == 3 else None, + include_top=False, input_tensor=model_input ) x = resnet50.get_layer("conv4_block6_2_relu").output x = DilatedSpatialPyramidPooling(x)