mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
eo: don't downsample ConvNeXt at beginning
This commit is contained in:
parent
d5fdab50ed
commit
b5e68fc1a3
2 changed files with 9 additions and 3 deletions
|
@ -63,7 +63,12 @@ dataset_train, dataset_validate = dataset_encoderonly(
|
||||||
|
|
||||||
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, **kwargs):
|
def make_encoderonly(windowsize, channels, encoder="convnext", water_bins=2, **kwargs):
|
||||||
if encoder == "convnext":
|
if encoder == "convnext":
|
||||||
model = make_convnext(input_shape=(windowsize, windowsize, channels), num_classes=water_bins, **kwargs)
|
model = make_convnext(
|
||||||
|
input_shape=(windowsize, windowsize, channels),
|
||||||
|
num_classes=water_bins,
|
||||||
|
downsample_at_start=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
elif encoder == "resnet":
|
elif encoder == "resnet":
|
||||||
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
|
layer_in = tf.keras.Input(shape=(windowsize, windowsize, channels))
|
||||||
layer_next = tf.keras.applications.resnet50.ResNet50(
|
layer_next = tf.keras.applications.resnet50.ResNet50(
|
||||||
|
|
|
@ -53,7 +53,8 @@ def convnext(
|
||||||
depths = [3, 3, 9, 3],
|
depths = [3, 3, 9, 3],
|
||||||
dims = [96, 192, 384, 768],
|
dims = [96, 192, 384, 768],
|
||||||
drop_path_rate = 0.,
|
drop_path_rate = 0.,
|
||||||
classifier_activation = 'softmax'
|
classifier_activation = 'softmax',
|
||||||
|
downsample_at_start = True
|
||||||
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
|
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
|
||||||
):
|
):
|
||||||
print("convnext:shape IN x", x.shape)
|
print("convnext:shape IN x", x.shape)
|
||||||
|
@ -67,7 +68,7 @@ def convnext(
|
||||||
x = tf.keras.layers.Conv2D(
|
x = tf.keras.layers.Conv2D(
|
||||||
dim,
|
dim,
|
||||||
kernel_size = 4,
|
kernel_size = 4,
|
||||||
strides = 4,
|
strides = 4 if downsample_at_start else 1,
|
||||||
padding = "valid",
|
padding = "valid",
|
||||||
name = "downsample_layers.0.0_conv"
|
name = "downsample_layers.0.0_conv"
|
||||||
)(x)
|
)(x)
|
||||||
|
|
Loading…
Reference in a new issue