diff --git a/aimodel/src/deeplabv3_plus_test_rainfall.py b/aimodel/src/deeplabv3_plus_test_rainfall.py index b410cc9..e5c6dd8 100755 --- a/aimodel/src/deeplabv3_plus_test_rainfall.py +++ b/aimodel/src/deeplabv3_plus_test_rainfall.py @@ -100,7 +100,8 @@ def DilatedSpatialPyramidPooling(dspp_input): def DeeplabV3Plus(image_size, num_classes, num_channels=3): model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) resnet50 = tf.keras.applications.ResNet50( - weights="imagenet", include_top=False, input_tensor=model_input + weights="imagenet" if num_channels == 3 else None, + include_top=False, input_tensor=model_input ) x = resnet50.get_layer("conv4_block6_2_relu").output x = DilatedSpatialPyramidPooling(x)