mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
fix embedding confusion
This commit is contained in:
parent
847cd97ec4
commit
3f7db6fa78
2 changed files with 12 additions and 10 deletions
|
@ -83,14 +83,11 @@ class RainfallWaterSegmenter(object):
|
|||
steps_per_epoch=10 # For testing
|
||||
)
|
||||
|
||||
def embed(self, dataset):
|
||||
i_batch = -1
|
||||
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
|
||||
i_batch += 1
|
||||
rainfall = self.model(batch[0], training=False) # (rainfall_embed, water)
|
||||
|
||||
for step in tf.unstack(rainfall, axis=0):
|
||||
yield step
|
||||
def embed(self, rainfall_embed):
|
||||
rainfall = self.model(rainfall_embed, training=False) # (rainfall_embed, water)
|
||||
|
||||
for step in tf.unstack(rainfall, axis=0):
|
||||
yield step
|
||||
|
||||
|
||||
# def embed_rainfall(self, dataset):
|
||||
|
|
|
@ -94,19 +94,24 @@ def do_png(args, ai, dataset, model_code):
|
|||
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
|
||||
for item in gen:
|
||||
rainfall, water = item
|
||||
|
||||
water_predict_batch = ai.embed(rainfall)
|
||||
water = tf.unstack(tf.squeeze(water), axis=0)
|
||||
|
||||
i_batch = 0
|
||||
for water_predict in water_predict_batch:
|
||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||
# [ width, height ]
|
||||
water = tf.squeeze(water)
|
||||
water_item = water[i_batch]
|
||||
|
||||
segmentation_plot(
|
||||
water, water_predict,
|
||||
water_item, water_predict,
|
||||
model_code,
|
||||
args.output.replace("+d", str(i))
|
||||
)
|
||||
|
||||
i_batch += 1
|
||||
i += 1
|
||||
|
||||
if i % 100 == 0:
|
||||
|
|
Loading…
Reference in a new issue