mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
fix embedding confusion
This commit is contained in:
parent
847cd97ec4
commit
3f7db6fa78
2 changed files with 12 additions and 10 deletions
|
@ -83,14 +83,11 @@ class RainfallWaterSegmenter(object):
|
||||||
steps_per_epoch=10 # For testing
|
steps_per_epoch=10 # For testing
|
||||||
)
|
)
|
||||||
|
|
||||||
def embed(self, dataset):
|
def embed(self, rainfall_embed):
|
||||||
i_batch = -1
|
rainfall = self.model(rainfall_embed, training=False) # (rainfall_embed, water)
|
||||||
for batch in batched_iterator(dataset, tensors_in_item=2, batch_size=self.batch_size):
|
|
||||||
i_batch += 1
|
for step in tf.unstack(rainfall, axis=0):
|
||||||
rainfall = self.model(batch[0], training=False) # (rainfall_embed, water)
|
yield step
|
||||||
|
|
||||||
for step in tf.unstack(rainfall, axis=0):
|
|
||||||
yield step
|
|
||||||
|
|
||||||
|
|
||||||
# def embed_rainfall(self, dataset):
|
# def embed_rainfall(self, dataset):
|
||||||
|
|
|
@ -94,19 +94,24 @@ def do_png(args, ai, dataset, model_code):
|
||||||
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
|
gen = batched_iterator(dataset, tensors_in_item=2, batch_size=model_params["batch_size"])
|
||||||
for item in gen:
|
for item in gen:
|
||||||
rainfall, water = item
|
rainfall, water = item
|
||||||
|
|
||||||
water_predict_batch = ai.embed(rainfall)
|
water_predict_batch = ai.embed(rainfall)
|
||||||
|
water = tf.unstack(tf.squeeze(water), axis=0)
|
||||||
|
|
||||||
|
i_batch = 0
|
||||||
for water_predict in water_predict_batch:
|
for water_predict in water_predict_batch:
|
||||||
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
# [ width, height, softmax_probabilities ] → [ batch, width, height ]
|
||||||
water_predict = tf.math.argmax(water_predict, axis=-1)
|
water_predict = tf.math.argmax(water_predict, axis=-1)
|
||||||
# [ width, height ]
|
# [ width, height ]
|
||||||
water = tf.squeeze(water)
|
water_item = water[i_batch]
|
||||||
|
|
||||||
segmentation_plot(
|
segmentation_plot(
|
||||||
water, water_predict,
|
water_item, water_predict,
|
||||||
model_code,
|
model_code,
|
||||||
args.output.replace("+d", str(i))
|
args.output.replace("+d", str(i))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
i_batch += 1
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
if i % 100 == 0:
|
if i % 100 == 0:
|
||||||
|
|
Loading…
Reference in a new issue