ai: predict oops

This commit is contained in:
Starbeamrainbowlabs 2022-09-14 17:37:48 +01:00
parent fa3165a5b2
commit a96cefde62
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 10 additions and 16 deletions

View file

@ -88,18 +88,15 @@ class RainfallWaterContraster(object):
i_batch = -1 i_batch = -1
for batch in dataset: for batch in dataset:
i_batch += 1 i_batch += 1
result_batch = self.model_predict(batch[0]) # ((rainfall, water), dummy_label) rainfall = self.model_predict(batch) # ((rainfall, water), dummy_label)
rainfall, water = tf.unstack(result_batch, axis=-2)
rainfall = tf.unstack(rainfall, axis=0) for step in tf.unstack(rainfall, axis=0):
water = tf.unstack(water, axis=0)
for step in zip(rainfall, water):
yield step yield step
def embed_rainfall(self, dataset): # def embed_rainfall(self, dataset):
result = [] # result = []
for batch in dataset: # for batch in dataset:
result_batch = self.model_predict(batch) # result_batch = self.model_predict(batch)
result.extend(tf.unstack(result_batch, axis=0)) # result.extend(tf.unstack(result_batch, axis=0))
return result # return result

View file

@ -68,11 +68,8 @@ def run(args):
handle = handle_open(filepath_output, "w") handle = handle_open(filepath_output, "w")
i = 0 i = 0
for rainfall, water in ai.embed(dataset): for rainfall in ai.embed(dataset):
handle.write(json.dumps({ handle.write(json.dumps(rainfall.numpy().tolist(), separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
"rainfall": rainfall.numpy().tolist(),
"water": water.numpy().tolist()
}, separators=(',', ':'))+"\n") # Ref https://stackoverflow.com/a/64710892/1460422
if i == 0 or i % 1000 == 0: if i == 0 or i % 1000 == 0:
sys.stderr.write(f"[pretrain:predict] STEP {i}\r") sys.stderr.write(f"[pretrain:predict] STEP {i}\r")