fix embedding confusion

This commit is contained in:
Starbeamrainbowlabs 2022-10-21 15:15:59 +01:00
parent 847cd97ec4
commit 3f7db6fa78
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 12 additions and 10 deletions

View file

@ -83,14 +83,11 @@ class RainfallWaterSegmenter(object):
steps_per_epoch=10 # For testing steps_per_epoch=10 # For testing
) )
def embed(self, dataset): def embed(self, rainfall_embed):
i_batch = -1 rainfall = self.model(rainfall_embed, training=False) # (rainfall_embed, water)
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
i_batch += 1 for step in tf.unstack(rainfall, axis=0):
rainfall = self.model(batch[0], training=False) # (rainfall_embed, water) yield step
for step in tf.unstack(rainfall, axis=0):
yield step
# def embed_rainfall(self, dataset): # def embed_rainfall(self, dataset):

View file

@ -94,19 +94,24 @@ def do_png(args, ai, dataset, model_code):
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"]) gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
for item in gen: for item in gen:
rainfall, water = item rainfall, water = item
water_predict_batch = ai.embed(rainfall) water_predict_batch = ai.embed(rainfall)
water = tf.unstack(tf.squeeze(water), axis=0)
i_batch = 0
for water_predict in water_predict_batch: for water_predict in water_predict_batch:
# [ width, height, softmax_probabilities ] → [ batch, width, height ] # [ width, height, softmax_probabilities ] → [ batch, width, height ]
water_predict = tf.math.argmax(water_predict, axis=-1) water_predict = tf.math.argmax(water_predict, axis=-1)
# [ width, height ] # [ width, height ]
water = tf.squeeze(water) water_item = water[i_batch]
segmentation_plot( segmentation_plot(
water, water_predict, water_item, water_predict,
model_code, model_code,
args.output.replace("+d", str(i)) args.output.replace("+d", str(i))
) )
i_batch += 1
i += 1 i += 1
if i % 100 == 0: if i % 100 == 0: