dataset: explicit reshape

This commit is contained in:
Starbeamrainbowlabs 2022-09-02 16:57:59 +01:00
parent c066ea331b
commit c89677abd7
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -22,6 +22,9 @@ def parse_item(metadata):
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
# [channels, width, height] → [width, height, channels] - ref ConvNeXt does not support data_format=channels_first
rainfall = tf.reshape(rainfall, tf.constant(metadata["rainfallradar"], dtype=tf.int32))
water = tf.reshape(water, tf.constant(metadata["waterdepth"], dtype=tf.int32))
rainfall = tf.transpose(rainfall, [1, 2, 0])
# [width, height] → [width, height, channels]
water = tf.expand_dims(water, axis=-1)