fix embedding confusion

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 15:15:59 +01:00
parent 847cd97ec4
commit 3f7db6fa78
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 12 additions and 10 deletions

View file

@ -83,14 +83,11 @@ class RainfallWaterSegmenter(object):
steps_per_epoch=10 # For testing
)
def embed(self, dataset):
i_batch = -1
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
i_batch += 1
rainfall = self.model(batch[0], training=False) # (rainfall_embed, water)
def embed(self, rainfall_embed):
rainfall = self.model(rainfall_embed, training=False) # (rainfall_embed, water)
for step in tf.unstack(rainfall, axis=0):
yield step
for step in tf.unstack(rainfall, axis=0):
yield step
# def embed_rainfall(self, dataset):

View file

@ -94,19 +94,24 @@ def do_png(args, ai, dataset, model_code):
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
for item in gen:
rainfall, water = item
water_predict_batch = ai.embed(rainfall)
water = tf.unstack(tf.squeeze(water), axis=0)
i_batch = 0
for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ]
water = tf.squeeze(water)
water_item = water[i_batch]
segmentation_plot(
water, water_predict,
water_item, water_predict,
model_code,
args.output.replace("+d", str(i))
)
i_batch += 1
i += 1
if i % 100 == 0: