model_segmentation: stack not reshape

This commit is contained in:
Starbeamrainbowlabs 2022-10-25 21:25:15 +01:00
parent 98417a3e06
commit 6a29105f56
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 29 additions and 2 deletions

View file

@ -0,0 +1,26 @@
import tensorflow as tf
# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb
class LayerStack2Image(tf.keras.layers.Layer):
def __init__(self, target_width, target_height, name=None, **kwargs):
super(LayerStack2Image, self).__init__(name=name)
self.param_target_width = target_width
self.param_target_height = target_height
def get_config(self):
config = super(LayerStack2Image, self).get_config()
config.update({
"target_width": self.param_target_width,
"target_height": self.param_target_height,
})
return config
def call(self, input_thing, **kwargs):
result = tf.stack([ input_thing for i in range(self.param_target_width) ], axis=-1)
result = tf.stack([ result for i in range(self.param_target_height) ], axis=-1)
result = tf.stack([ result ], axis=-1) # channel dimension
return result

View file

@ -4,7 +4,7 @@ from loguru import logger
import tensorflow as tf
from .components.convnext_inverse import do_convnext_inverse
from .components.LayerStack2Image import LayerStack2Image
def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, water_bins=2):
"""Makes a new rainfall / waterdepth segmentation head model.
@ -31,7 +31,8 @@ def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="conv
layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu1")(layer_next)
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm1", epsilon=1e-6)(layer_next)
layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next)
layer_next = LayerStack2Image(target_width=4, target_height=4)(layer_next)
# layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next)
layer_next = tf.keras.layers.Dense(name="cns.stage.begin.dense2", units=feature_dim_in)(layer_next)
layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu2")(layer_next)
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm2", epsilon=1e-6)(layer_next)