Fix bugs in AI trainer. It works, but what precisely it's doing is up for debate....

Starbeamrainbowlabs 2019-07-18 17:22:37 +01:00
parent 192e31a925
commit 5c7cc1e906
4 changed files with 21 additions and 10 deletions


@@ -29,7 +29,7 @@ class RSSIRepo {
 			readings.latitude,
 			readings.longitude
 		FROM rssis
-		JOIN readings ON rssis.gateway_id = readings.id
+		JOIN readings ON rssis.reading_id = readings.id
 		WHERE gateway_id = :gateway_id`).iterate({
 			gateway_id
 		});
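The join previously matched rssis.gateway_id against readings.id, pairing each RSSI row with whatever reading happened to share that number rather than with its parent reading; joining on rssis.reading_id restores the intended lookup. A minimal sketch of the corrected query, assuming a better-sqlite3-style database handle and a hypothetical filename (the real code lives inside the RSSIRepo class):

// Sketch only: not the real RSSIRepo class.
import Database from 'better-sqlite3';

const db = new Database('readings.sqlite'); // hypothetical filename

const stmt = db.prepare(`SELECT
	readings.latitude,
	readings.longitude
FROM rssis
JOIN readings ON rssis.reading_id = readings.id
WHERE gateway_id = :gateway_id`);

// Named parameters are bound from a plain object; iterate() streams rows lazily.
for (const row of stmt.iterate({ gateway_id: 1 })) {
	console.log(row.latitude, row.longitude);
}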


@@ -77,6 +77,12 @@ export default async function(c) {
 			break;
+		case "train-ai":
+			l.log(`${a.fgreen}${a.hicol}Training AIs${a.reset}`);
+			
+			let ai_trainer = c.cradle.AITrainer;
+			await ai_trainer.train_all();
+			break;
 		default:
 			l.error(`Error: Subcommand '${extras[0]}' not recognised.`);
 			l.error(`Perhaps you mistyped it, or it hasn't been implemented yet?`);
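The new train-ai subcommand pulls the trainer out of the dependency-injection container and trains one model per gateway. A rough sketch of that resolution step, assuming the container is awilix (suggested by c.cradle here and a.asClass(...) in the setup file); the class body is a stub, not the project's real trainer:

// Illustrative only: AITrainer here is a stand-in.
import { createContainer, asClass } from 'awilix';

class AITrainer {
	async train_all() {
		console.log('would train one model per gateway here');
	}
}

const c = createContainer();
c.register({ AITrainer: asClass(AITrainer) });

// Reading a property off the cradle constructs the class and its dependencies.
const ai_trainer = c.cradle.AITrainer;
await ai_trainer.train_all();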


@@ -38,7 +38,7 @@ export default async function setup() {
 		DataProcessor: a.asClass(DataProcessor),
 		AITrainer: a.asClass(AITrainer),
-		DatasetTrainer: a.asClass(DatasetTrainer)
+		DatasetFetcher: a.asClass(DatasetFetcher)
 	});
 	
 	// Enable / disable colourising the output
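This registration key is what the trainer relies on: AITrainer's constructor destructures DatasetFetcher by name from the container, so while the class was registered under the key DatasetTrainer there was nothing for that name to resolve to. A small sketch of the name-matching contract, assuming awilix's default PROXY injection mode (an assumption based on the destructured constructors); both classes are stubs:

// Sketch of the name-matching contract only.
import { createContainer, asClass } from 'awilix';

class DatasetFetcher {}

class AITrainer {
	constructor({ DatasetFetcher }) {
		// Could not be resolved while the entry was keyed "DatasetTrainer".
		this.dataset_fetcher = DatasetFetcher;
	}
}

const c = createContainer();
c.register({
	AITrainer: asClass(AITrainer),
	DatasetFetcher: asClass(DatasetFetcher)
});

console.log(c.cradle.AITrainer.dataset_fetcher instanceof DatasetFetcher); // true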


@@ -3,8 +3,9 @@
 import tf from '@tensorflow/tfjs-node-gpu';
 
 class AITrainer {
-	constructor({ settings, GatewayRepo, DatasetFetcher }) {
+	constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
 		this.settings = settings;
+		this.l = log;
 		this.dataset_fetcher = DatasetFetcher;
 		this.repo_gateway = GatewayRepo;
 		this.model = this.generate_model();
@@ -15,7 +16,7 @@ class AITrainer {
 		model.add(tf.layers.dense({
 			units: 256, // 256 nodes
 			activation: "sigmoid", // Sigmoid activation function
-			inputShape: [3], // 2 inputs - lat and long
+			inputShape: [2], // 2 inputs - lat and long
 		}))
 		model.add(tf.layers.dense({
 			units: 1, // 1 output value - RSSI
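The comment was already right: each training example carries two features, latitude and longitude, and the model predicts a single RSSI value, so the first layer's inputShape has to be [2]. A minimal shape check, with illustrative layer sizes rather than the project's real settings:

// Shape sanity check only; the model is untrained, so the value is meaningless.
import tf from '@tensorflow/tfjs-node-gpu';

const model = tf.sequential();
model.add(tf.layers.dense({
	units: 256,
	activation: "sigmoid",
	inputShape: [2] // two inputs per example: latitude and longitude
}));
model.add(tf.layers.dense({ units: 1 })); // one output: predicted RSSI

// A [batch, 2] tensor of (lat, long) pairs matches the declared input shape;
// with inputShape: [3] this call would throw a shape mismatch error.
model.predict(tf.tensor2d([[53.367, -1.471]])).print();

The compile/train hunk that follows carries the other fixes: the loss becomes the tf.losses.absoluteDifference function rather than a string tfjs-layers does not accept, the metrics drop "accuracy" (not meaningful for a single-value regression output) in favour of the meanSquaredError metric function, and the shuffle buffer grows to several batches so tf.data has more rows to mix before batching. The subtler change is to .bind(): binding null fixes the leading gateway_id argument but discards the receiver, so any use of `this` inside the fetcher's generator methods throws. A small sketch of that failure mode, using a hypothetical stand-in for DatasetFetcher:

// Hypothetical stand-in for DatasetFetcher, showing why the receiver matters.
class FakeFetcher {
	constructor(scale) { this.scale = scale; }
	
	*fetch_input(gateway_id) {
		// Uses `this`, so the method cannot run with a null receiver.
		yield [gateway_id * this.scale];
	}
}

const fetcher = new FakeFetcher(0.5);

const broken = fetcher.fetch_input.bind(null, 42);   // receiver discarded
const fixed = fetcher.fetch_input.bind(fetcher, 42); // receiver preserved

try {
	console.log([...broken()]);
} catch (e) {
	console.log('broken:', e.message); // TypeError: `this` is null inside the generator
}
console.log('fixed:', [...fixed()]); // [ [ 21 ] ]

Either bound function is still something tf.data.generator() can call to get a fresh iterator; only the `this` binding differs.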
@@ -24,33 +25,37 @@ class AITrainer {
 		model.compile({
 			optimizer: tf.train.adam(),
-			loss: "absoluteDifference",
-			metrics: [ "accuracy", "meanSquaredError" ]
+			loss: tf.losses.absoluteDifference,
+			metrics: [ tf.metrics.meanSquaredError ]
 		});
+		
+		this.l.log(`Model:`);
+		model.summary();
 		
 		return model;
 	}
 	
 	async train_all() {
 		for(let gateway of this.repo_gateway.iterate()) {
-			await this.train_dataset(gateway.id);
+			await this.train_gateway(gateway.id);
 		}
 	}
 	
 	async train_gateway(gateway_id) {
 		let dataset_input = tf.data.generator(
-			this.dataset_fetcher.fetch_input.bind(null, gateway_id)
+			this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
 		);
 		let dataset_output = tf.data.generator(
-			this.dataset_fetcher.fetch_output.bind(null, gateway_id)
+			this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
 		);
 		
 		let dataset = tf.data.zip({
 			xs: dataset_input,
 			ys: dataset_output
-		}).shuffle(this.settings.ai.batch_size)
+		}).shuffle(this.settings.ai.batch_size * 4)
 			.batch(this.settings.ai.batch_size);
 		
 		let result = await this.model.fitDataset(dataset, {
 			epochs: this.settings.ai.epochs,
 			batchSize: this.settings.ai.batch_size