mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
eo: don't downsample ConvNeXt at beginning
This commit is contained in:
parent
d5fdab50ed
commit
b5e68fc1a3
2 changed files with 9 additions and 3 deletions
|
@ -63,7 +63,12 @@ dataset_train, dataset_validate = dataset_encoderonly(
|
|||
|
||||
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, **kwargs):
|
||||
if encoder == "convnext":
|
||||
model = make_convnext(input_shape=(windowsize, windowsize, channels), num_classes=water_bins, **kwargs)
|
||||
model = make_convnext(
|
||||
input_shape=(windowsize, windowsize, channels),
|
||||
num_classes=water_bins,
|
||||
downsample_at_start=False,
|
||||
**kwargs
|
||||
)
|
||||
elif encoder == "resnet":
|
||||
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
|
||||
layer_next = tf.keras.applications.resnet50.ResNet50(
|
||||
|
|
|
@ -53,7 +53,8 @@ def convnext(
|
|||
depths = [3, 3, 9, 3],
|
||||
dims = [96, 192, 384, 768],
|
||||
drop_path_rate = 0.,
|
||||
classifier_activation = 'softmax'
|
||||
classifier_activation = 'softmax',
|
||||
downsample_at_start = True
|
||||
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
|
||||
):
|
||||
print("convnext:shape IN x", x.shape)
|
||||
|
@ -67,7 +68,7 @@ def convnext(
|
|||
x = tf.keras.layers.Conv2D(
|
||||
dim,
|
||||
kernel_size = 4,
|
||||
strides = 4,
|
||||
strides = 4 if downsample_at_start else 1,
|
||||
padding = "valid",
|
||||
name = "downsample_layers.0.0_conv"
|
||||
)(x)
|
||||
|
|
Loading…
Reference in a new issue