Try to add an additional input to the neural network, but it's not working.
This commit is contained in:
parent
9e93940225
commit
43f99d6b5d
7 changed files with 121 additions and 37 deletions
|
@ -1,7 +1,11 @@
|
||||||
{
|
{
|
||||||
"ecmaVersion": 8,
|
"ecmaVersion": 8,
|
||||||
"libs": [],
|
"libs": [
|
||||||
"loadEagerly": [],
|
"browser"
|
||||||
|
],
|
||||||
|
"loadEagerly": [
|
||||||
|
"server/**.mjs"
|
||||||
|
],
|
||||||
"dontLoad": [],
|
"dontLoad": [],
|
||||||
"plugins": {
|
"plugins": {
|
||||||
"doc_comment": true
|
"doc_comment": true
|
||||||
|
|
|
@ -6,6 +6,13 @@ function normalise(value, { min : input_min, max: input_max }, { min : output_mi
|
||||||
) + output_min
|
) + output_min
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clamps a value to fit within the specified bounds.
|
||||||
|
* @param {number} value The number to clamp.
|
||||||
|
* @param {number} min The minimum value the number can be.
|
||||||
|
* @param {number} max The maximum value the number can be.
|
||||||
|
* @return {number} The clamped number.
|
||||||
|
*/
|
||||||
function clamp(value, min, max) {
|
function clamp(value, min, max) {
|
||||||
if(value > max) return max;
|
if(value > max) return max;
|
||||||
if(value < min) return min;
|
if(value < min) return min;
|
||||||
|
|
|
@ -36,6 +36,10 @@ class GatewayRepo {
|
||||||
return count > 0;
|
return count > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
get_by_id(id) {
|
||||||
|
return this.db.prepare(`SELECT * FROM gateways WHERE id = :id`).get({ id });
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Iterates over all the gateways in the table.
|
* Iterates over all the gateways in the table.
|
||||||
* TODO: Use Symbol.iterator here?
|
* TODO: Use Symbol.iterator here?
|
||||||
|
|
|
@ -39,6 +39,10 @@ class ReadingRepo {
|
||||||
return count > 0;
|
return count > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
iterate_unreceived() {
|
||||||
|
return this.db.prepare(`SELECT * FROM readings WHERE data_rate IS NULL;`).iterate();
|
||||||
|
}
|
||||||
|
|
||||||
iterate() {
|
iterate() {
|
||||||
return this.db.prepare(`SELECT * FROM readings`).iterate();
|
return this.db.prepare(`SELECT * FROM readings`).iterate();
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,6 +66,8 @@ rssi_max = 0
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
# The number of epochs to train for.
|
# The number of epochs to train for.
|
||||||
epochs = 5
|
epochs = 5
|
||||||
|
# The percentage, between 0 and 1, of data that should be set aside for validation.
|
||||||
|
validation_split = 0.1
|
||||||
|
|
||||||
# The directory to output trained AIs to, relative to the repository root.
|
# The directory to output trained AIs to, relative to the repository root.
|
||||||
output_directory = "app/ais/"
|
output_directory = "app/ais/"
|
||||||
|
|
|
@ -6,7 +6,8 @@ import fs from 'fs';
|
||||||
import tf from '@tensorflow/tfjs-node-gpu';
|
import tf from '@tensorflow/tfjs-node-gpu';
|
||||||
|
|
||||||
class AITrainer {
|
class AITrainer {
|
||||||
constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
|
constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
|
||||||
|
this.a = ansi;
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
this.root_dir = root_dir;
|
this.root_dir = root_dir;
|
||||||
this.l = log;
|
this.l = log;
|
||||||
|
@ -17,9 +18,9 @@ class AITrainer {
|
||||||
generate_model() {
|
generate_model() {
|
||||||
let model = tf.sequential();
|
let model = tf.sequential();
|
||||||
model.add(tf.layers.dense({
|
model.add(tf.layers.dense({
|
||||||
units: 256, // 256 nodes
|
units: 64, // 64 nodes
|
||||||
activation: "sigmoid", // Sigmoid activation function
|
activation: "sigmoid", // Sigmoid activation function
|
||||||
inputShape: [2], // 2 inputs - lat and long
|
inputShape: [3], // 3 inputs - lat, long, and distance from gateway
|
||||||
}))
|
}))
|
||||||
model.add(tf.layers.dense({
|
model.add(tf.layers.dense({
|
||||||
units: 1, // 1 output value - RSSI
|
units: 1, // 1 output value - RSSI
|
||||||
|
@ -82,30 +83,36 @@ class AITrainer {
|
||||||
* @return {Promise} A promise that resolves when training and serialisation is complete.
|
* @return {Promise} A promise that resolves when training and serialisation is complete.
|
||||||
*/
|
*/
|
||||||
async train_gateway(gateway_id, destination_filename) {
|
async train_gateway(gateway_id, destination_filename) {
|
||||||
|
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
||||||
let model = this.generate_model();
|
let model = this.generate_model();
|
||||||
|
|
||||||
// TODO: Add samples here for locations that the gateway does NOT cover too
|
// let dataset_input = tf.data.generator(
|
||||||
let dataset_input = tf.data.generator(
|
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||||
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
// );
|
||||||
|
// let dataset_output = tf.data.generator(
|
||||||
|
// this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// let dataset = tf.data.zip({
|
||||||
|
// xs: dataset_input,
|
||||||
|
// ys: dataset_output
|
||||||
|
// }).shuffle(this.settings.ai.batch_size * 4)
|
||||||
|
// .batch(this.settings.ai.batch_size);
|
||||||
|
//
|
||||||
|
let datasets = this.dataset_fetcher.fetch_all(gateway_id);
|
||||||
|
|
||||||
|
let result = await model.fit(
|
||||||
|
tf.tensor(datasets.input),
|
||||||
|
tf.tensor(datasets.output),
|
||||||
|
{
|
||||||
|
epochs: this.settings.ai.epochs,
|
||||||
|
batchSize: this.settings.ai.batch_size,
|
||||||
|
validationSplit: this.settings.ai.validation_split
|
||||||
|
}
|
||||||
);
|
);
|
||||||
let dataset_output = tf.data.generator(
|
|
||||||
this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
let dataset = tf.data.zip({
|
|
||||||
xs: dataset_input,
|
|
||||||
ys: dataset_output
|
|
||||||
}).shuffle(this.settings.ai.batch_size * 4)
|
|
||||||
.batch(this.settings.ai.batch_size);
|
|
||||||
|
|
||||||
|
|
||||||
let result = await model.fitDataset(dataset, {
|
|
||||||
epochs: this.settings.ai.epochs,
|
|
||||||
batchSize: this.settings.ai.batch_size
|
|
||||||
});
|
|
||||||
|
|
||||||
await model.save(`file://${destination_filename}`);
|
await model.save(`file://${destination_filename}`);
|
||||||
console.log(result);
|
// console.log(result);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,26 +1,78 @@
|
||||||
"use strict";
|
"use strict";
|
||||||
|
|
||||||
|
import haversine from 'haversine-distance';
|
||||||
|
|
||||||
import { normalise, clamp } from '../../common/Math.mjs';
|
import { normalise, clamp } from '../../common/Math.mjs';
|
||||||
|
|
||||||
class DatasetFetcher {
|
class DatasetFetcher {
|
||||||
constructor({ settings, RSSIRepo }) {
|
constructor({ settings, GatewayRepo, RSSIRepo, ReadingRepo }) {
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
|
this.repo_gateway = GatewayRepo;
|
||||||
this.repo_rssi = RSSIRepo;
|
this.repo_rssi = RSSIRepo;
|
||||||
|
this.repo_reading = ReadingRepo;
|
||||||
|
}
|
||||||
|
|
||||||
|
normalise_latlng(lat, lng) {
|
||||||
|
return [
|
||||||
|
normalise(lat,
|
||||||
|
{ min: -90, max: +90 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
),
|
||||||
|
normalise(lng,
|
||||||
|
{ min: -180, max: +180 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
)
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch_all(gateway_id) {
|
||||||
|
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
|
||||||
|
|
||||||
|
let result = {
|
||||||
|
input: [],
|
||||||
|
output: []
|
||||||
|
};
|
||||||
|
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
||||||
|
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||||
|
let distance_from_gateway = haversine(gateway_location, rssi);
|
||||||
|
|
||||||
|
next_input.push(clamp(
|
||||||
|
normalise(distance_from_gateway,
|
||||||
|
{ min: 0, max: 20000 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
),
|
||||||
|
0, 1))
|
||||||
|
|
||||||
|
result.input.push(next_input);
|
||||||
|
|
||||||
|
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
|
||||||
|
|
||||||
|
result.output.push([
|
||||||
|
clamp(normalise(rssi.rssi,
|
||||||
|
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
), 0, 1)
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for(let reading of this.repo_reading.iterate_unreceived()) {
|
||||||
|
result.input.push(this.normalise_latlng(
|
||||||
|
reading.latitude,
|
||||||
|
reading.longitude
|
||||||
|
));
|
||||||
|
|
||||||
|
result.output.push([ 0 ]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
*fetch_input(gateway_id) {
|
*fetch_input(gateway_id) {
|
||||||
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id))
|
||||||
yield [
|
yield this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||||
normalise(rssi.latitude,
|
|
||||||
{ min: -90, max: +90 },
|
for(let rssi of this.repo_reading.iterate_unreceived())
|
||||||
{ min: 0, max: 1 }
|
yield this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||||
),
|
|
||||||
normalise(rssi.longitude,
|
|
||||||
{ min: -180, max: +180 },
|
|
||||||
{ min: 0, max: 1 }
|
|
||||||
)
|
|
||||||
];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*fetch_output(gateway_id) {
|
*fetch_output(gateway_id) {
|
||||||
|
@ -32,6 +84,10 @@ class DatasetFetcher {
|
||||||
), 0, 1)
|
), 0, 1)
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
// Yield 0 for every unreceived message, since we want to train it to predict a *terrible* signal where the gateway is not
|
||||||
|
for(let rssi of this.repo_reading.iterate_unreceived()) {
|
||||||
|
yield [ 0 ];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue