From 6a29105f569023d29c3f0f86cfda1abd15d994f0 Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Tue, 25 Oct 2022 21:25:15 +0100 Subject: [PATCH] model_segmentation: stack not reshape --- .../src/lib/ai/components/LayerStack2Image.py | 26 +++++++++++++++++++ .../ai/model_rainfallwater_segmentation.py | 5 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 aimodel/src/lib/ai/components/LayerStack2Image.py diff --git a/aimodel/src/lib/ai/components/LayerStack2Image.py b/aimodel/src/lib/ai/components/LayerStack2Image.py new file mode 100644 index 0000000..4b9d7f9 --- /dev/null +++ b/aimodel/src/lib/ai/components/LayerStack2Image.py @@ -0,0 +1,26 @@ +import tensorflow as tf + +# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb + +class LayerStack2Image(tf.keras.layers.Layer): + def __init__(self, target_width, target_height, name=None, **kwargs): + super(LayerStack2Image, self).__init__(name=name) + + self.param_target_width = target_width + self.param_target_height = target_height + + def get_config(self): + config = super(LayerStack2Image, self).get_config() + config.update({ + "target_width": self.param_target_width, + "target_height": self.param_target_height, + }) + + return config + + def call(self, input_thing, **kwargs): + result = tf.stack([ input_thing for i in range(self.param_target_width) ], axis=-1) + result = tf.stack([ result for i in range(self.param_target_height) ], axis=-1) + result = tf.stack([ result ], axis=-1) # channel dimension + return result + \ No newline at end of file diff --git a/aimodel/src/lib/ai/model_rainfallwater_segmentation.py b/aimodel/src/lib/ai/model_rainfallwater_segmentation.py index d83b95b..f4d12d6 100644 --- a/aimodel/src/lib/ai/model_rainfallwater_segmentation.py +++ b/aimodel/src/lib/ai/model_rainfallwater_segmentation.py @@ -4,7 +4,7 @@ from loguru import logger import tensorflow as tf from .components.convnext_inverse import do_convnext_inverse - +from .components.LayerStack2Image import LayerStack2Image def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="convnext_i_xtiny", batch_size=64, water_bins=2): """Makes a new rainfall / waterdepth segmentation head model. @@ -31,7 +31,8 @@ def model_rainfallwater_segmentation(metadata, shape_water_out, model_arch="conv layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu1")(layer_next) layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm1", epsilon=1e-6)(layer_next) - layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next) + layer_next = LayerStack2Image(target_width=4, target_height=4)(layer_next) + # layer_next = tf.keras.layers.Reshape((4, 4, math.floor(feature_dim_in/(4*4))), name="cns.stable_begin.reshape")(layer_next) layer_next = tf.keras.layers.Dense(name="cns.stage.begin.dense2", units=feature_dim_in)(layer_next) layer_next = tf.keras.layers.ReLU(name="cns.stage_begin.relu2")(layer_next) layer_next = tf.keras.layers.LayerNormalization(name="cns.stage_begin.norm2", epsilon=1e-6)(layer_next)