From 192e31a92533395fb5b01e665b5a3c908321f33a Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Thu, 18 Jul 2019 16:34:25 +0100 Subject: [PATCH] Write the the data input, but it's untested. There's also some confusion over the batch size. If it doesn't work, we probably need to revise our understanding of the batch size in different contexts. --- server/settings.default.toml | 8 ++++++ server/train-ai/AITrainer.mjs | 26 +++++++++++++---- server/train-ai/DatasetFetcher.mjs | 45 +++++++++++++++--------------- 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/server/settings.default.toml b/server/settings.default.toml index 76cebe0..c828646 100644 --- a/server/settings.default.toml +++ b/server/settings.default.toml @@ -59,6 +59,14 @@ devices = [ rssi_min = -150 rssi_max = 0 +# Data is streamed from the SQLite databse. The batch size specifies how many +# rows to gather for training at once. +# Note that the entire available dataset will eventually end up being used - +# this setting just controls how much of ti is in memory at once. +batch_size = 32 +# The number of epochs to train for. +epochs = 5 + [logging] # The format the date displayed when logging things should take. # Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely)) diff --git a/server/train-ai/AITrainer.mjs b/server/train-ai/AITrainer.mjs index e056b28..d5100e9 100644 --- a/server/train-ai/AITrainer.mjs +++ b/server/train-ai/AITrainer.mjs @@ -31,15 +31,31 @@ class AITrainer { return model; } - async train() { + async train_all() { for(let gateway of this.repo_gateway.iterate()) { - let dataset = this.dataset_fetcher.fetch(gateway.id); - await this.train_dataset(dataset); + await this.train_dataset(gateway.id); } } - async train_dataset(dataset) { - // TODO: Fill this in + async train_gateway(gateway_id) { + let dataset_input = tf.data.generator( + this.dataset_fetcher.fetch_input.bind(null, gateway_id) + ); + let dataset_output = tf.data.generator( + this.dataset_fetcher.fetch_output.bind(null, gateway_id) + ); + + let dataset = tf.data.zip({ + xs: dataset_input, + ys: dataset_output + }).shuffle(this.settings.ai.batch_size) + .batch(this.settings.ai.batch_size); + + let result = await this.model.fitDataset(dataset, { + epochs: this.settings.ai.epochs, + batchSize: this.settings.ai.batch_size + }); + console.log(result); } } diff --git a/server/train-ai/DatasetFetcher.mjs b/server/train-ai/DatasetFetcher.mjs index bd843fc..e5b0d87 100644 --- a/server/train-ai/DatasetFetcher.mjs +++ b/server/train-ai/DatasetFetcher.mjs @@ -8,29 +8,30 @@ class DatasetFetcher { this.repo_rssi = RSSIRepo; } - fetch(gateway_id) { - let result = []; - for(let rssi of this.repo_rssi.iterate_gateway(gateway_id) { - result.push({ - input: [ - normalise(rssi.latitude, - { min: -90, max: +90 }, - { min: 0, max: 1 } - ), - normalise(rssi.longitude, - { min: -180, max: +180 }, - { min: 0, max: 1 } - ) - ], - output: [ - clamp(normalise(rssis.rssi, - { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, - { min: 0, max: 1 } - ), 0, 1) - ] - }); + *fetch_input(gateway_id) { + for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) { + yield [ + normalise(rssi.latitude, + { min: -90, max: +90 }, + { min: 0, max: 1 } + ), + normalise(rssi.longitude, + { min: -180, max: +180 }, + { min: 0, max: 1 } + ) + ]; + } + } + + *fetch_output(gateway_id) { + for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) { + yield [ + clamp(normalise(rssi.rssi, + { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, + { min: 0, max: 1 } + ), 0, 1) + ]; } - return result; } }