dlr eo: cheese it by upsampling and then downsampling again

This commit is contained in:
Starbeamrainbowlabs 2023-02-23 16:47:00 +00:00
parent 96b94ec55b
commit 9f1cee2927
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 9 additions and 8 deletions

View file

@ -108,28 +108,29 @@ if PATH_CHECKPOINT is None:
out_pool = tf.keras.layers.UpSampling2D(
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
)(x)
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
output = convolution_block(x, kernel_size=1)
return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
resnet50 = tf.keras.applications.ResNet50(
weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=model_input
include_top=False, input_tensor=x
)
x = resnet50.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x)
input_a = tf.keras.layers.UpSampling2D(
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]), # <--- UPSAMPLE after pyramid
interpolation="bilinear",
)(x)
input_b = resnet50.get_layer("conv2_block3_2_relu").output
@ -139,7 +140,7 @@ if PATH_CHECKPOINT is None:
x = convolution_block(x)
x = convolution_block(x)
x = tf.keras.layers.UpSampling2D(
size=(image_size // x.shape[1], image_size // x.shape[2]),
size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
interpolation="bilinear",
)(x)
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)

View file

@ -94,7 +94,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
# ONE-HOT [LOSS cross entropy]
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
# SPARSE [LOSS dice]
# SPARSE [LOSS dice / sparse cross entropy]
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
if do_remove_isolated_pixels:
water = remove_isolated_pixels(water)