Try to add an additional input to the neural network, but it's not working.

This commit is contained in:
Starbeamrainbowlabs 2019-07-29 14:17:28 +01:00
parent 9e93940225
commit 43f99d6b5d
7 changed files with 121 additions and 37 deletions

View file

@ -1,7 +1,11 @@
{
"ecmaVersion": 8,
"libs": [],
"loadEagerly": [],
"libs": [
"browser"
],
"loadEagerly": [
"server/**.mjs"
],
"dontLoad": [],
"plugins": {
"doc_comment": true

View file

@ -6,6 +6,13 @@ function normalise(value, { min : input_min, max: input_max }, { min : output_mi
) + output_min
}
/**
* Clamps a value to fit within the specified bounds.
* @param {number} value The number to clamp.
* @param {number} min The minimum value the number can be.
* @param {number} max The maximum value the number can be.
* @return {number} The clamped number.
*/
function clamp(value, min, max) {
if(value > max) return max;
if(value < min) return min;

View file

@ -36,6 +36,10 @@ class GatewayRepo {
return count > 0;
}
get_by_id(id) {
return this.db.prepare(`SELECT * FROM gateways WHERE id = :id`).get({ id });
}
/**
* Iterates over all the gateways in the table.
* TODO: Use Symbol.iterator here?

View file

@ -39,6 +39,10 @@ class ReadingRepo {
return count > 0;
}
iterate_unreceived() {
return this.db.prepare(`SELECT * FROM readings WHERE data_rate IS NULL;`).iterate();
}
iterate() {
return this.db.prepare(`SELECT * FROM readings`).iterate();
}

View file

@ -66,6 +66,8 @@ rssi_max = 0
batch_size = 32
# The number of epochs to train for.
epochs = 5
# The percentage, between 0 and 1, of data that should be set aside for validation.
validation_split = 0.1
# The directory to output trained AIs to, relative to the repository root.
output_directory = "app/ais/"

View file

@ -6,7 +6,8 @@ import fs from 'fs';
import tf from '@tensorflow/tfjs-node-gpu';
class AITrainer {
constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
this.a = ansi;
this.settings = settings;
this.root_dir = root_dir;
this.l = log;
@ -17,9 +18,9 @@ class AITrainer {
generate_model() {
let model = tf.sequential();
model.add(tf.layers.dense({
units: 256, // 256 nodes
units: 64, // 64 nodes
activation: "sigmoid", // Sigmoid activation function
inputShape: [2], // 2 inputs - lat and long
inputShape: [3], // 3 inputs - lat, long, and distance from gateway
}))
model.add(tf.layers.dense({
units: 1, // 1 output value - RSSI
@ -82,30 +83,36 @@ class AITrainer {
* @return {Promise} A promise that resolves when training and serialisation is complete.
*/
async train_gateway(gateway_id, destination_filename) {
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
let model = this.generate_model();
// TODO: Add samples here for locations that the gateway does NOT cover too
let dataset_input = tf.data.generator(
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
);
let dataset_output = tf.data.generator(
this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
);
// let dataset_input = tf.data.generator(
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
// );
// let dataset_output = tf.data.generator(
// this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
// );
//
// let dataset = tf.data.zip({
// xs: dataset_input,
// ys: dataset_output
// }).shuffle(this.settings.ai.batch_size * 4)
// .batch(this.settings.ai.batch_size);
//
let datasets = this.dataset_fetcher.fetch_all(gateway_id);
let dataset = tf.data.zip({
xs: dataset_input,
ys: dataset_output
}).shuffle(this.settings.ai.batch_size * 4)
.batch(this.settings.ai.batch_size);
let result = await model.fitDataset(dataset, {
let result = await model.fit(
tf.tensor(datasets.input),
tf.tensor(datasets.output),
{
epochs: this.settings.ai.epochs,
batchSize: this.settings.ai.batch_size
});
batchSize: this.settings.ai.batch_size,
validationSplit: this.settings.ai.validation_split
}
);
await model.save(`file://${destination_filename}`);
console.log(result);
// console.log(result);
return true;
}

View file

@ -1,26 +1,78 @@
"use strict";
import haversine from 'haversine-distance';
import { normalise, clamp } from '../../common/Math.mjs';
class DatasetFetcher {
constructor({ settings, RSSIRepo }) {
constructor({ settings, GatewayRepo, RSSIRepo, ReadingRepo }) {
this.settings = settings;
this.repo_gateway = GatewayRepo;
this.repo_rssi = RSSIRepo;
this.repo_reading = ReadingRepo;
}
*fetch_input(gateway_id) {
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
yield [
normalise(rssi.latitude,
normalise_latlng(lat, lng) {
return [
normalise(lat,
{ min: -90, max: +90 },
{ min: 0, max: 1 }
),
normalise(rssi.longitude,
normalise(lng,
{ min: -180, max: +180 },
{ min: 0, max: 1 }
)
];
}
fetch_all(gateway_id) {
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
let result = {
input: [],
output: []
};
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
let distance_from_gateway = haversine(gateway_location, rssi);
next_input.push(clamp(
normalise(distance_from_gateway,
{ min: 0, max: 20000 },
{ min: 0, max: 1 }
),
0, 1))
result.input.push(next_input);
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
result.output.push([
clamp(normalise(rssi.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 }
), 0, 1)
]);
}
for(let reading of this.repo_reading.iterate_unreceived()) {
result.input.push(this.normalise_latlng(
reading.latitude,
reading.longitude
));
result.output.push([ 0 ]);
}
return result;
}
*fetch_input(gateway_id) {
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id))
yield this.normalise_latlng(rssi.latitude, rssi.longitude);
for(let rssi of this.repo_reading.iterate_unreceived())
yield this.normalise_latlng(rssi.latitude, rssi.longitude);
}
*fetch_output(gateway_id) {
@ -32,6 +84,10 @@ class DatasetFetcher {
), 0, 1)
];
}
// Yield 0 for every unreceived message, since we want to train it to predict a *terrible* signal where the gateway is not
for(let rssi of this.repo_reading.iterate_unreceived()) {
yield [ 0 ];
}
}
}