// LoRaWAN-Signal-Mapping/server/train-ai/AITrainer.mjs

"use strict";
import path from 'path';
import fs from 'fs';
import brain from 'brain.js';
/**
 * Trains brain.js neural networks from the datasets supplied by the injected
 * dataset fetcher, and serialises each trained network - plus an index.json
 * file describing them - into the configured AI output directory.
 */
class AITrainer {
	/**
	 * @param {Object} deps                 Injected dependencies.
	 * @param {Object} deps.ansi            ANSI escape sequences used when logging (fgreen, hicol, reset).
	 * @param {Object} deps.settings        Settings object; the `ai` section is read here.
	 * @param {Object} deps.log             Logger exposing log() and warn().
	 * @param {string} deps.root_dir        Directory whose parent contains the AI output directory.
	 * @param {Object} deps.GatewayRepo     Repository exposing iterate() over gateways.
	 * @param {Object} deps.DatasetFetcher  Fetcher exposing fetch_all(gateway_id).
	 */
	constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
		this.a = ansi;
		this.settings = settings;
		this.root_dir = root_dir;
		this.l = log;
		this.dataset_fetcher = DatasetFetcher;
		this.repo_gateway = GatewayRepo;
	}
	
	/**
	 * Builds the path to a file inside the AI output directory.
	 * Every output file goes through here so that the models and the index
	 * always end up in the same directory.
	 * @param	{string}	filename	The name of the file inside the output directory.
	 * @return	{string}	The path to the file.
	 */
	_output_path(filename) {
		return path.join(
			this.root_dir,
			"..",
			this.settings.ai.output_directory,
			filename
		);
	}
	
	/**
	 * Writes index.json, which lists the trained networks.
	 * @param	{Object[]}	index	The index entries (id, filename, latitude, longitude, net_settings).
	 * @return	{Promise}	A promise that resolves once the index has been written.
	 */
	async _write_index(index) {
		const index_path = this._output_path("index.json");
		// mkdir with recursive: true is a no-op if the directory already
		// exists, so no separate existence check is needed. It matters when
		// no model file has been written beforehand (e.g. an empty gateway list).
		await fs.promises.mkdir(path.dirname(index_path), { recursive: true });
		await fs.promises.writeFile(
			index_path,
			JSON.stringify({
				properties: {
					rssi_min: -150,
					rssi_max: 0
				},
				index
			}, null, "\t")
		);
	}
	
	/**
	 * Trains a single network by passing a null gateway id to train_gateway(),
	 * saves it as ai.json, and writes an index.json containing that 1 entry.
	 * @return	{Promise}	A promise that resolves when training and serialisation is complete.
	 */
	async train_unified() {
		const filepath = this._output_path("ai.json");
		await fs.promises.mkdir(path.dirname(filepath), { recursive: true });
		
		const training_result = await this.train_gateway(null, filepath);
		
		await this._write_index([{
			id: "Unified AI (not a real gateway)",
			filename: "ai.json",
			latitude: 0, longitude: 0,
			net_settings: training_result.net_settings
		}]);
	}
	
	/**
	 * Trains 1 network per gateway in the gateway repository, saving each to
	 * <gateway id>.json, and then writes an index.json listing every network
	 * that was trained successfully.
	 * @return	{Promise}	A promise that resolves when all training and serialisation is complete.
	 */
	async train_all() {
		// The output directory is the same for every gateway, so create it once
		await fs.promises.mkdir(path.dirname(this._output_path("index.json")), { recursive: true });
		
		const index = [];
		for(const gateway of this.repo_gateway.iterate()) {
			const filepath = this._output_path(`${gateway.id}.json`);
			
			const result = await this.train_gateway(gateway.id, filepath);
			if(!result || result.success === false) {
				this.l.warn(`Warning: Failed to train AI for ${gateway.id}.`);
				continue;
			}
			this.l.log(`Saved to ${filepath}`);
			
			index.push({
				id: gateway.id,
				filename: path.basename(filepath),
				latitude: gateway.latitude,
				longitude: gateway.longitude,
				net_settings: result.net_settings
			});
		}
		
		await this._write_index(index);
	}
	
	/**
	 * Trains a brain.js neural network on the dataset fetched for a gateway,
	 * and writes the trained network to disk as tab-indented JSON.
	 * @param	{string|null}	gateway_id				The id of the gateway to train an AI for. This is passed straight to the dataset fetcher; train_unified() passes null here.
	 * @param	{string}		destination_filename	The path to the file to serialise the trained network to.
	 * @return	{Promise<{success: boolean, net_settings: Object}>}	Resolves once the network has been written to disk. net_settings holds the options the network was constructed with. Errors from training or writing reject the promise.
	 */
	async train_gateway(gateway_id, destination_filename) {
		this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
		
		// Create the neural network
		const net_settings = {
			hiddenLayers: this.settings.ai.network_arch,
			activation: "sigmoid"
		};
		const net = new brain.NeuralNetwork(net_settings);
		
		// Fetch the dataset
		const dataset = this.dataset_fetcher.fetch_all(gateway_id);
		
		await net.trainAsync(dataset, {
			iterations: this.settings.ai.epochs,
			errorThresh: this.settings.ai.error_threshold,
			learningRate: this.settings.ai.learning_rate,
			momentum: this.settings.ai.momentum,
			log: (log_line) => this.l.log(`[brain.js] ${log_line}`),
			logPeriod: 50,
			timeout: Infinity
		});
		
		// The indentation arguments belong to JSON.stringify(), not to writeFile()
		await fs.promises.writeFile(
			destination_filename,
			JSON.stringify(net.toJSON(), null, "\t")
		);
		
		return {
			success: true,
			net_settings
		};
	}
}
export default AITrainer;