mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
model_segmentation: stack not reshape
This commit is contained in:
parent
98417a3e06
commit
6a29105f56
2 changed files with 29 additions and 2 deletions
26
aimodel/src/lib/ai/components/LayerStack2Image.py
Normal file
26
aimodel/src/lib/ai/components/LayerStack2Image.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import tensorflow as tf
|
||||
|
||||
# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb
|
||||
|
||||
class LayerStack2Image(tf.keras.layers.Layer):
|
||||
def __init__(self, target_width, target_height, name=None, **kwargs):
|
||||
super(LayerStack2Image, self).__init__(name=name)
|
||||
|
||||
self.param_target_width = target_width
|
||||
self.param_target_height = target_height
|
||||
|
||||
def get_config(self):
|
||||
config = super(LayerStack2Image, self).get_config()
|
||||
config.update({
|
||||
"target_width": self.param_target_width,
|
||||
"target_height": self.param_target_height,
|
||||
})
|
||||
|
||||
return config
|
||||
|
||||
def call(self, input_thing, **kwargs):
|
||||
result = tf.stack([ input_thing for i in range(self.param_target_width) ], axis=-1)
|
||||
result = tf.stack([ result for i in range(self.param_target_height) ], axis=-1)
|
||||
result = tf.stack([ result ], axis=-1) # channel dimension
|
||||
return result
|
||||
|
|
@ -4,7 +4,7 @@ from loguru import logger
|
|||
import tensorflow as tf
|
||||
|
||||
from .components.convnext_inverse import do_convnext_inverse
|
||||
|
||||
from .components.LayerStack2Image import LayerStack2Image
|
||||
|
||||
def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, water_bins=2):
|
||||
"""Makes a new rainfall / waterdepth segmentation head model.
|
||||
|
@ -31,7 +31,8 @@ def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="conv
|
|||
layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu1")(layer_next)
|
||||
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm1", epsilon=1e-6)(layer_next)
|
||||
|
||||
layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next)
|
||||
layer_next = LayerStack2Image(target_width=4, target_height=4)(layer_next)
|
||||
# layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next)
|
||||
layer_next = tf.keras.layers.Dense(name="cns.stage.begin.dense2", units=feature_dim_in)(layer_next)
|
||||
layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu2")(layer_next)
|
||||
layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm2", epsilon=1e-6)(layer_next)
|
||||
|
|
Loading…
Reference in a new issue