mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-25 10:32:59 +00:00
dlr eo: cheese it by upsampling and then downsampling again
This commit is contained in:
parent
96b94ec55b
commit
9f1cee2927
2 changed files with 9 additions and 8 deletions
|
@ -121,15 +121,16 @@ if PATH_CHECKPOINT is None:
|
||||||
|
|
||||||
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
|
def DeeplabV3Plus(image_size, num_classes, num_channels=3):
|
||||||
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
|
||||||
|
x = tf.keras.layers.UpSampling2D(size=2)(model_input)
|
||||||
resnet50 = tf.keras.applications.ResNet50(
|
resnet50 = tf.keras.applications.ResNet50(
|
||||||
weights="imagenet" if num_channels == 3 else None,
|
weights="imagenet" if num_channels == 3 else None,
|
||||||
include_top=False, input_tensor=model_input
|
include_top=False, input_tensor=x
|
||||||
)
|
)
|
||||||
x = resnet50.get_layer("conv4_block6_2_relu").output
|
x = resnet50.get_layer("conv4_block6_2_relu").output
|
||||||
x = DilatedSpatialPyramidPooling(x)
|
x = DilatedSpatialPyramidPooling(x)
|
||||||
|
|
||||||
input_a = tf.keras.layers.UpSampling2D(
|
input_a = tf.keras.layers.UpSampling2D(
|
||||||
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
|
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]), # <--- UPSAMPLE after pyramid
|
||||||
interpolation="bilinear",
|
interpolation="bilinear",
|
||||||
)(x)
|
)(x)
|
||||||
input_b = resnet50.get_layer("conv2_block3_2_relu").output
|
input_b = resnet50.get_layer("conv2_block3_2_relu").output
|
||||||
|
@ -139,7 +140,7 @@ if PATH_CHECKPOINT is None:
|
||||||
x = convolution_block(x)
|
x = convolution_block(x)
|
||||||
x = convolution_block(x)
|
x = convolution_block(x)
|
||||||
x = tf.keras.layers.UpSampling2D(
|
x = tf.keras.layers.UpSampling2D(
|
||||||
size=(image_size // x.shape[1], image_size // x.shape[2]),
|
size=(image_size // x.shape[1], image_size // x.shape[2]), # <--- UPSAMPLE at end
|
||||||
interpolation="bilinear",
|
interpolation="bilinear",
|
||||||
)(x)
|
)(x)
|
||||||
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
|
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
|
||||||
|
|
|
@ -94,7 +94,7 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
# ONE-HOT [LOSS cross entropy]
|
# ONE-HOT [LOSS cross entropy]
|
||||||
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
# water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.int32)
|
||||||
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
# water = tf.one_hot(water, water_bins, axis=-1, dtype=tf.int32)
|
||||||
# SPARSE [LOSS dice]
|
# SPARSE [LOSS dice / sparse cross entropy]
|
||||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
||||||
if do_remove_isolated_pixels:
|
if do_remove_isolated_pixels:
|
||||||
water = remove_isolated_pixels(water)
|
water = remove_isolated_pixels(water)
|
||||||
|
|
Loading…
Reference in a new issue