DeepLabV3+: have argument for number of channels

This commit is contained in:
Starbeamrainbowlabs 2022-12-14 17:36:30 +00:00
parent 1dc2ec3a46
commit 6ce121f861
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -118,8 +118,8 @@ def DilatedSpatialPyramidPooling(dspp_input):
return output
def DeeplabV3Plus(image_size, num_classes):
model_input = tf.keras.Input(shape=(image_size, image_size, 3))
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
resnet50 = tf.keras.applications.ResNet50(
weights="imagenet", include_top=False, input_tensor=model_input
)