Try to add an additional input to the neural network, but it's not working.
parent 9e93940225
commit 43f99d6b5d
7 changed files with 121 additions and 37 deletions
@@ -1,7 +1,11 @@
 {
 	"ecmaVersion": 8,
-	"libs": [],
-	"loadEagerly": [],
+	"libs": [
+		"browser"
+	],
+	"loadEagerly": [
+		"server/**.mjs"
+	],
 	"dontLoad": [],
 	"plugins": {
 		"doc_comment": true
@@ -6,6 +6,13 @@ function normalise(value, { min : input_min, max: input_max }, { min : output_min, max: output_max }) {
 	) + output_min
 }
 
+/**
+ * Clamps a value to fit within the specified bounds.
+ * @param  {number} value The number to clamp.
+ * @param  {number} min   The minimum value the number can be.
+ * @param  {number} max   The maximum value the number can be.
+ * @return {number}       The clamped number.
+ */
 function clamp(value, min, max) {
 	if(value > max) return max;
 	if(value < min) return min;
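For context, `normalise` (whose signature is visible in the hunk header) maps a value from one input range onto an output range. Its body is inferred here from that signature and the trailing `) + output_min` visible above, so treat this as a sketch rather than the verbatim source:

	function normalise(value, { min: input_min, max: input_max }, { min: output_min, max: output_max }) {
		return (output_max - output_min) * (
			(value - input_min) / (input_max - input_min)
		) + output_min
	}
	
	// e.g. map a latitude into the network's 0..1 input range, clamping stray values
	clamp(normalise(51.5, { min: -90, max: +90 }, { min: 0, max: 1 }), 0, 1); // ≈ 0.786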
@@ -36,6 +36,10 @@ class GatewayRepo {
 		return count > 0;
 	}
 	
+	get_by_id(id) {
+		return this.db.prepare(`SELECT * FROM gateways WHERE id = :id`).get({ id });
+	}
+	
 	/**
 	 * Iterates over all the gateways in the table.
 	 * TODO: Use Symbol.iterator here?
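The `:id` placeholder bound from the `{ id }` object matches better-sqlite3's named-parameter API; assuming that is the driver behind `this.db`, `.get()` returns the first matching row, or `undefined` if there is none:

	// Sketch, assuming better-sqlite3 semantics; repo is a GatewayRepo instance.
	let gateway = repo.get_by_id(42);
	if(gateway === undefined)
		throw new Error(`No gateway with id 42 was found.`);
	console.log(gateway.id);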
@@ -39,6 +39,10 @@ class ReadingRepo {
 		return count > 0;
 	}
 	
+	iterate_unreceived() {
+		return this.db.prepare(`SELECT * FROM readings WHERE data_rate IS NULL;`).iterate();
+	}
+	
 	iterate() {
 		return this.db.prepare(`SELECT * FROM readings`).iterate();
 	}
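`.iterate()` (again matching better-sqlite3's statement API) returns an iterator that fetches rows one at a time, so the new method can stream arbitrarily many unreceived readings without loading them all into memory at once:

	// Sketch: consuming the new iterator; repo is a ReadingRepo instance.
	for(let reading of repo.iterate_unreceived())
		console.log(reading.latitude, reading.longitude);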
@@ -66,6 +66,8 @@ rssi_max = 0
 batch_size = 32
 # The number of epochs to train for.
 epochs = 5
+# The percentage, between 0 and 1, of data that should be set aside for validation.
+validation_split = 0.1
 
 # The directory to output trained AIs to, relative to the repository root.
 output_directory = "app/ais/"
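These settings feed straight into the TensorFlow.js training call later in this diff; with `validationSplit: 0.1`, `model.fit()` holds back the final 10% of the samples for validation rather than training on them. Roughly (mirroring the options object in the AITrainer hunk below):

	{
		epochs: 5,            // settings.ai.epochs
		batchSize: 32,        // settings.ai.batch_size
		validationSplit: 0.1  // settings.ai.validation_split
	}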
@@ -6,7 +6,8 @@ import fs from 'fs';
 import tf from '@tensorflow/tfjs-node-gpu';
 
 class AITrainer {
-	constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
+	constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
+		this.a = ansi;
 		this.settings = settings;
 		this.root_dir = root_dir;
 		this.l = log;
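The injected `ansi` object is used for coloured log output; `train_gateway()` below references its `fgreen`, `hicol`, and `reset` properties. The real helper isn't shown in this diff, but a minimal stand-in would look like this (the exact escape codes are assumptions):

	// Hypothetical stand-in for the injected ansi helper; only the
	// properties referenced in this diff are included.
	const ansi = {
		fgreen: "\u001b[32m", // green foreground
		hicol: "\u001b[1m",   // bold / high intensity
		reset: "\u001b[0m"    // reset all attributes
	};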
@@ -17,9 +18,9 @@ class AITrainer {
 	generate_model() {
 		let model = tf.sequential();
 		model.add(tf.layers.dense({
-			units: 256, // 256 nodes
+			units: 64, // 64 nodes
 			activation: "sigmoid", // Sigmoid activation function
-			inputShape: [2], // 2 inputs - lat and long
+			inputShape: [3], // 3 inputs - lat, long, and distance from gateway
 		}))
 		model.add(tf.layers.dense({
 			units: 1, // 1 output value - RSSI
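With `inputShape: [3]`, every training sample must now carry 3 features instead of 2. A quick hypothetical sanity check that a single sample flows through the resized layers:

	// Hypothetical check - not part of the commit; trainer is an AITrainer instance.
	let model = trainer.generate_model();
	// 1 sample × 3 features in → 1 sample × 1 predicted RSSI value out
	model.predict(tf.tensor([[ 0.786, 0.497, 0.031 ]])).print();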
@@ -82,30 +83,36 @@ class AITrainer {
 	 * @return {Promise} A promise that resolves when training and serialisation is complete.
 	 */
 	async train_gateway(gateway_id, destination_filename) {
-		this.l.log(`Training AI for gateway ${gateway_id}`);
+		this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
 		let model = this.generate_model();
 		
-		let dataset_input = tf.data.generator(
-			this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
-		);
-		let dataset_output = tf.data.generator(
-			this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
-		);
-		
-		let dataset = tf.data.zip({
-			xs: dataset_input,
-			ys: dataset_output
-		}).shuffle(this.settings.ai.batch_size * 4)
-			.batch(this.settings.ai.batch_size);
-		
-		
-		let result = await model.fitDataset(dataset, {
-			epochs: this.settings.ai.epochs,
-			batchSize: this.settings.ai.batch_size
-		});
+		// TODO: Add samples here for locations that the gateway does NOT cover too
+		// let dataset_input = tf.data.generator(
+		// 	this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
+		// );
+		// let dataset_output = tf.data.generator(
+		// 	this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
+		// );
+		// 
+		// let dataset = tf.data.zip({
+		// 	xs: dataset_input,
+		// 	ys: dataset_output
+		// }).shuffle(this.settings.ai.batch_size * 4)
+		// 	.batch(this.settings.ai.batch_size);
+		// 
+		let datasets = this.dataset_fetcher.fetch_all(gateway_id);
+		let result = await model.fit(
+			tf.tensor(datasets.input),
+			tf.tensor(datasets.output),
+			{
+				epochs: this.settings.ai.epochs,
+				batchSize: this.settings.ai.batch_size,
+				validationSplit: this.settings.ai.validation_split
+			}
+		);
 		
 		await model.save(`file://${destination_filename}`);
-		console.log(result);
+		// console.log(result);
 		
 		return true;
 	}
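Given the commit message says this isn't working yet, one thing worth checking is that `fetch_all()` (shown in the next file's diff) returns rectangular arrays: received readings get 3 input features, but unreceived ones get only 2, which would hand `tf.tensor()` a ragged array. A hypothetical check:

	// Hypothetical debugging snippet - not part of the commit.
	// tf.tensor() requires a rectangular array, so ragged input rows are one plausible culprit.
	let datasets = trainer.dataset_fetcher.fetch_all(gateway_id);
	console.log(datasets.input.filter((row) => row.length !== 3).length); // anything non-zero will break tf.tensor()
	console.log(tf.tensor(datasets.input).shape); // expected: [sample_count, 3]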
@@ -1,26 +1,78 @@
 "use strict";
 
+import haversine from 'haversine-distance';
+
 import { normalise, clamp } from '../../common/Math.mjs';
 
 class DatasetFetcher {
-	constructor({ settings, RSSIRepo }) {
+	constructor({ settings, GatewayRepo, RSSIRepo, ReadingRepo }) {
 		this.settings = settings;
+		this.repo_gateway = GatewayRepo;
 		this.repo_rssi = RSSIRepo;
+		this.repo_reading = ReadingRepo;
 	}
 	
+	normalise_latlng(lat, lng) {
+		return [
+			normalise(lat,
+				{ min: -90, max: +90 },
+				{ min: 0, max: 1 }
+			),
+			normalise(lng,
+				{ min: -180, max: +180 },
+				{ min: 0, max: 1 }
+			)
+		];
+	}
+	
+	fetch_all(gateway_id) {
+		let gateway_location = this.repo_gateway.get_by_id(gateway_id);
+		
+		let result = {
+			input: [],
+			output: []
+		};
+		for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
+			let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
+			let distance_from_gateway = haversine(gateway_location, rssi);
+			
+			next_input.push(clamp(
+				normalise(distance_from_gateway,
+					{ min: 0, max: 20000 },
+					{ min: 0, max: 1 }
+				),
+			0, 1))
+			
+			result.input.push(next_input);
+			
+			console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
+			
+			result.output.push([
+				clamp(normalise(rssi.rssi,
+					{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
+					{ min: 0, max: 1 }
+				), 0, 1)
+			]);
+		}
+		
+		for(let reading of this.repo_reading.iterate_unreceived()) {
+			result.input.push(this.normalise_latlng(
+				reading.latitude,
+				reading.longitude
+			));
+			
+			result.output.push([ 0 ]);
+		}
+		
+		return result;
+	}
+	
 	*fetch_input(gateway_id) {
-		for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
-			yield [
-				normalise(rssi.latitude,
-					{ min: -90, max: +90 },
-					{ min: 0, max: 1 }
-				),
-				normalise(rssi.longitude,
-					{ min: -180, max: +180 },
-					{ min: 0, max: 1 }
-				)
-			];
-		}
+		for(let rssi of this.repo_rssi.iterate_gateway(gateway_id))
+			yield this.normalise_latlng(rssi.latitude, rssi.longitude);
+		
+		for(let rssi of this.repo_reading.iterate_unreceived())
+			yield this.normalise_latlng(rssi.latitude, rssi.longitude);
 	}
 	
 	*fetch_output(gateway_id) {
@@ -32,6 +84,10 @@ class DatasetFetcher {
 				), 0, 1)
 			];
 		}
+		// Yield 0 for every unreceived message, since we want to train it to predict a *terrible* signal where the gateway is not
+		for(let rssi of this.repo_reading.iterate_unreceived()) {
+			yield [ 0 ];
+		}
 	}
 }
 
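One gap worth flagging (see also the ragged-array note above): the unreceived-readings loop in `fetch_all()` pushes only the two values from `normalise_latlng()`, while received readings get a third distance feature, so the input rows disagree in length. A hypothetical fix, mirroring the received-readings branch of the same method:

	// Hypothetical fix sketch - not in this commit. Gives unreceived readings the
	// same 3-feature shape as received ones, so tf.tensor() gets a rectangular array.
	for(let reading of this.repo_reading.iterate_unreceived()) {
		let next_input = this.normalise_latlng(reading.latitude, reading.longitude);
		next_input.push(clamp(
			normalise(haversine(gateway_location, reading),
				{ min: 0, max: 20000 },
				{ min: 0, max: 1 }
			),
		0, 1));
		result.input.push(next_input);
		result.output.push([ 0 ]);
	}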