diff --git a/server/Helpers/Math.mjs b/server/Helpers/Math.mjs new file mode 100644 index 0000000..d94b1a7 --- /dev/null +++ b/server/Helpers/Math.mjs @@ -0,0 +1,16 @@ +"use strict"; + +function normalise(value, { min : input_min, max: input_max }, { min : output_min, max: output_max }) { + return ( + ((value - input_min) / (input_max - input_min)) * (output_max - output_min) + ) + output_min +} + +function clamp(value, min, max) { + if(value > max) return max; + if(value < min) return min; + return value; +} + + +export { normalise, clamp }; diff --git a/server/Repos.SQLite/RSSIRepo.mjs b/server/Repos.SQLite/RSSIRepo.mjs index 88b53dc..17f8728 100644 --- a/server/Repos.SQLite/RSSIRepo.mjs +++ b/server/Repos.SQLite/RSSIRepo.mjs @@ -23,6 +23,18 @@ class RSSIRepo { } } + iterate_gateway(gateway_id) { + return this.db.prepare(`SELECT + rssis.*, + readings.latitude, + readings.longitude + FROM rssis + JOIN readings ON rssis.gateway_id = readings.id + WHERE gateway_id = :gateway_id`).iterate({ + gateway_id + }); + } + iterate() { return this.db.prepare(`SELECT * FROM rssis`).iterate(); } diff --git a/server/bootstrap/container.mjs b/server/bootstrap/container.mjs index f4122d9..5f6ade7 100644 --- a/server/bootstrap/container.mjs +++ b/server/bootstrap/container.mjs @@ -13,6 +13,9 @@ import MessageHandler from '../ttn-app-server/MessageHandler.mjs'; import DataProcessor from '../process-data/DataProcessor.mjs'; +import AITrainer from '../train-ai/AITrainer.mjs'; +import DatasetFetcher from '../train-ai/DatasetFetcher.mjs'; + import settings from './settings.mjs'; import database_init from '../bootstrap/database_init.mjs'; @@ -31,7 +34,11 @@ export default async function setup() { database: a.asFunction(database_init).singleton(), TTNAppServer: a.asClass(TTNAppServer), MessageHandler: a.asClass(MessageHandler), - DataProcessor: a.asClass(DataProcessor) + + DataProcessor: a.asClass(DataProcessor), + + AITrainer: a.asClass(AITrainer), + DatasetTrainer: a.asClass(DatasetTrainer) }); // Enable / disable colourising the output diff --git a/server/settings.default.toml b/server/settings.default.toml index 25ae495..76cebe0 100644 --- a/server/settings.default.toml +++ b/server/settings.default.toml @@ -54,6 +54,11 @@ devices = [ [ai] # Settings relating to the training of the AI. Note that a number of these settings can also be specified by environment variables, to aid with fiddling with the parameters to find the right settings. +# Min / max dataset values when training the AI, since neural networks only take values between 0 and 1. +# Note that changing these means that you've got to retrain the AIs all over again! +rssi_min = -150 +rssi_max = 0 + [logging] # The format the date displayed when logging things should take. # Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely)) diff --git a/server/train-ai/AITrainer.mjs b/server/train-ai/AITrainer.mjs index 889638d..e056b28 100644 --- a/server/train-ai/AITrainer.mjs +++ b/server/train-ai/AITrainer.mjs @@ -3,8 +3,10 @@ import tf from '@tensorflow/tfjs-node-gpu'; class AITrainer { - constructor({ settings }) { + constructor({ settings, GatewayRepo, DatasetFetcher }) { this.settings = settings; + this.dataset_fetcher = DatasetFetcher; + this.repo_gateway = GatewayRepo; this.model = this.generate_model(); } @@ -29,8 +31,15 @@ class AITrainer { return model; } - train() { - + async train() { + for(let gateway of this.repo_gateway.iterate()) { + let dataset = this.dataset_fetcher.fetch(gateway.id); + await this.train_dataset(dataset); + } + } + + async train_dataset(dataset) { + // TODO: Fill this in } } diff --git a/server/train-ai/DatasetFetcher.mjs b/server/train-ai/DatasetFetcher.mjs new file mode 100644 index 0000000..bd843fc --- /dev/null +++ b/server/train-ai/DatasetFetcher.mjs @@ -0,0 +1,37 @@ +"use strict"; + +import { normalise, clamp } from '../Helpers/Math.mjs'; + +class DatasetFetcher { + constructor({ settings, RSSIRepo }) { + this.settings = settings; + this.repo_rssi = RSSIRepo; + } + + fetch(gateway_id) { + let result = []; + for(let rssi of this.repo_rssi.iterate_gateway(gateway_id) { + result.push({ + input: [ + normalise(rssi.latitude, + { min: -90, max: +90 }, + { min: 0, max: 1 } + ), + normalise(rssi.longitude, + { min: -180, max: +180 }, + { min: 0, max: 1 } + ) + ], + output: [ + clamp(normalise(rssis.rssi, + { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, + { min: 0, max: 1 } + ), 0, 1) + ] + }); + } + return result; + } +} + +export default DatasetFetcher;