Hook up a dataset importer for training the AIs, but it's untested.
Also, we don't have any code that actually does the training itself either yet.
This commit is contained in:
parent
f7e2d77daa
commit
6dbb6c3b87
6 changed files with 90 additions and 4 deletions
16
server/Helpers/Math.mjs
Normal file
16
server/Helpers/Math.mjs
Normal file
|
@ -0,0 +1,16 @@
|
|||
"use strict";
|
||||
|
||||
function normalise(value, { min : input_min, max: input_max }, { min : output_min, max: output_max }) {
|
||||
return (
|
||||
((value - input_min) / (input_max - input_min)) * (output_max - output_min)
|
||||
) + output_min
|
||||
}
|
||||
|
||||
function clamp(value, min, max) {
|
||||
if(value > max) return max;
|
||||
if(value < min) return min;
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
export { normalise, clamp };
|
|
@ -23,6 +23,18 @@ class RSSIRepo {
|
|||
}
|
||||
}
|
||||
|
||||
iterate_gateway(gateway_id) {
|
||||
return this.db.prepare(`SELECT
|
||||
rssis.*,
|
||||
readings.latitude,
|
||||
readings.longitude
|
||||
FROM rssis
|
||||
JOIN readings ON rssis.gateway_id = readings.id
|
||||
WHERE gateway_id = :gateway_id`).iterate({
|
||||
gateway_id
|
||||
});
|
||||
}
|
||||
|
||||
iterate() {
|
||||
return this.db.prepare(`SELECT * FROM rssis`).iterate();
|
||||
}
|
||||
|
|
|
@ -13,6 +13,9 @@ import MessageHandler from '../ttn-app-server/MessageHandler.mjs';
|
|||
|
||||
import DataProcessor from '../process-data/DataProcessor.mjs';
|
||||
|
||||
import AITrainer from '../train-ai/AITrainer.mjs';
|
||||
import DatasetFetcher from '../train-ai/DatasetFetcher.mjs';
|
||||
|
||||
import settings from './settings.mjs';
|
||||
import database_init from '../bootstrap/database_init.mjs';
|
||||
|
||||
|
@ -31,7 +34,11 @@ export default async function setup() {
|
|||
database: a.asFunction(database_init).singleton(),
|
||||
TTNAppServer: a.asClass(TTNAppServer),
|
||||
MessageHandler: a.asClass(MessageHandler),
|
||||
DataProcessor: a.asClass(DataProcessor)
|
||||
|
||||
DataProcessor: a.asClass(DataProcessor),
|
||||
|
||||
AITrainer: a.asClass(AITrainer),
|
||||
DatasetTrainer: a.asClass(DatasetTrainer)
|
||||
});
|
||||
|
||||
// Enable / disable colourising the output
|
||||
|
|
|
@ -54,6 +54,11 @@ devices = [
|
|||
[ai]
|
||||
# Settings relating to the training of the AI. Note that a number of these settings can also be specified by environment variables, to aid with fiddling with the parameters to find the right settings.
|
||||
|
||||
# Min / max dataset values when training the AI, since neural networks only take values between 0 and 1.
|
||||
# Note that changing these means that you've got to retrain the AIs all over again!
|
||||
rssi_min = -150
|
||||
rssi_max = 0
|
||||
|
||||
[logging]
|
||||
# The format the date displayed when logging things should take.
|
||||
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))
|
||||
|
|
|
@ -3,8 +3,10 @@
|
|||
import tf from '@tensorflow/tfjs-node-gpu';
|
||||
|
||||
class AITrainer {
|
||||
constructor({ settings }) {
|
||||
constructor({ settings, GatewayRepo, DatasetFetcher }) {
|
||||
this.settings = settings;
|
||||
this.dataset_fetcher = DatasetFetcher;
|
||||
this.repo_gateway = GatewayRepo;
|
||||
this.model = this.generate_model();
|
||||
}
|
||||
|
||||
|
@ -29,8 +31,15 @@ class AITrainer {
|
|||
return model;
|
||||
}
|
||||
|
||||
train() {
|
||||
async train() {
|
||||
for(let gateway of this.repo_gateway.iterate()) {
|
||||
let dataset = this.dataset_fetcher.fetch(gateway.id);
|
||||
await this.train_dataset(dataset);
|
||||
}
|
||||
}
|
||||
|
||||
async train_dataset(dataset) {
|
||||
// TODO: Fill this in
|
||||
}
|
||||
}
|
||||
|
||||
|
|
37
server/train-ai/DatasetFetcher.mjs
Normal file
37
server/train-ai/DatasetFetcher.mjs
Normal file
|
@ -0,0 +1,37 @@
|
|||
"use strict";
|
||||
|
||||
import { normalise, clamp } from '../Helpers/Math.mjs';
|
||||
|
||||
class DatasetFetcher {
|
||||
constructor({ settings, RSSIRepo }) {
|
||||
this.settings = settings;
|
||||
this.repo_rssi = RSSIRepo;
|
||||
}
|
||||
|
||||
fetch(gateway_id) {
|
||||
let result = [];
|
||||
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id) {
|
||||
result.push({
|
||||
input: [
|
||||
normalise(rssi.latitude,
|
||||
{ min: -90, max: +90 },
|
||||
{ min: 0, max: 1 }
|
||||
),
|
||||
normalise(rssi.longitude,
|
||||
{ min: -180, max: +180 },
|
||||
{ min: 0, max: 1 }
|
||||
)
|
||||
],
|
||||
output: [
|
||||
clamp(normalise(rssis.rssi,
|
||||
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||
{ min: 0, max: 1 }
|
||||
), 0, 1)
|
||||
]
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
export default DatasetFetcher;
|
Loading…
Reference in a new issue