eo: don't downsample ConvNeXt at beginning

This commit is contained in:
Starbeamrainbowlabs 2023-01-20 18:49:46 +00:00
parent d5fdab50ed
commit b5e68fc1a3
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 9 additions and 3 deletions

View file

@ -63,7 +63,12 @@ dataset_train, dataset_validate = dataset_encoderonly(
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, **kwargs):
if encoder == "convnext":
model = make_convnext(input_shape=(windowsize, windowsize, channels), num_classes=water_bins, **kwargs)
model = make_convnext(
input_shape=(windowsize, windowsize, channels),
num_classes=water_bins,
downsample_at_start=False,
**kwargs
)
elif encoder == "resnet":
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
layer_next = tf.keras.applications.resnet50.ResNet50(

View file

@ -53,7 +53,8 @@ def convnext(
depths = [3, 3, 9, 3],
dims = [96, 192, 384, 768],
drop_path_rate = 0.,
classifier_activation = 'softmax'
classifier_activation = 'softmax',
downsample_at_start = True
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
):
print("convnext:shape IN x", x.shape)
@ -67,7 +68,7 @@ def convnext(
x = tf.keras.layers.Conv2D(
dim,
kernel_size = 4,
strides = 4,
strides = 4 if downsample_at_start else 1,
padding = "valid",
name = "downsample_layers.0.0_conv"
)(x)