LoRaWAN-Signal-Mapping/server/train-ai/AITrainer.mjs

121 lines
3.4 KiB
JavaScript

"use strict";
import path from 'path';
import fs from 'fs';
import tf from '@tensorflow/tfjs-node-gpu';
/**
 * Trains one small dense neural network per gateway and serialises each
 * trained model, plus an index.json describing them, into the configured
 * output directory.
 */
class AITrainer {
	/**
	 * @param {object} deps
	 * @param {object} deps.ansi Terminal colour codes (fgreen, hicol and reset are read here).
	 * @param {object} deps.settings Settings object; settings.ai supplies output_directory, rssi_min, rssi_max, epochs, batch_size and validation_split.
	 * @param {object} deps.log Logger exposing log() and warn().
	 * @param {string} deps.root_dir Directory whose parent is the base for settings.ai.output_directory.
	 * @param {object} deps.GatewayRepo Repository whose iterate() yields gateways with id, latitude and longitude.
	 * @param {object} deps.DatasetFetcher Provides fetch_all(gateway_id) returning { input, output } training arrays.
	 */
	constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
		this.a = ansi;
		this.settings = settings;
		this.root_dir = root_dir;
		this.l = log;
		this.dataset_fetcher = DatasetFetcher;
		this.repo_gateway = GatewayRepo;
	}
	
	/**
	 * Builds and compiles a fresh, untrained model: 3 inputs -> 64 sigmoid
	 * units -> 1 sigmoid output.
	 * @return {tf.Sequential} The compiled model. The caller owns it and must dispose() it.
	 */
	generate_model() {
		const model = tf.sequential();
		model.add(tf.layers.dense({
			units: 64, // 64 hidden nodes
			activation: "sigmoid", // Sigmoid activation function
			inputShape: [3], // 3 inputs - lat, long, and distance from gateway
		}));
		model.add(tf.layers.dense({
			units: 1, // 1 output value - RSSI
			// Sigmoid instead of the softmax used in the example code: softmax is meant
			// for classification, and over a single unit it would always output 1.
			// NOTE(review): sigmoid limits predictions to (0, 1), so the RSSI targets
			// presumably arrive normalised with rssi_min/rssi_max - confirm in DatasetFetcher.
			activation: "sigmoid"
		}));
		model.compile({
			optimizer: tf.train.adam(),
			loss: tf.losses.absoluteDifference, // optimise absolute error...
			metrics: [ tf.metrics.meanSquaredError ] // ...but also report squared error
		});
		this.l.log(`Model:`);
		model.summary();
		return model;
	}
	
	/**
	 * Trains a model for every gateway in the repository, then writes
	 * index.json listing the gateways whose model was trained successfully.
	 * Gateways that fail to train are warned about and left out of the index.
	 * @return {Promise<void>} Resolves once all models and index.json are written.
	 */
	async train_all() {
		// All artefacts share one directory, resolved relative to the parent of root_dir.
		const output_dir = path.join(this.root_dir, "..", this.settings.ai.output_directory);
		// Create it once up front. Recursive mkdir is a no-op if it already exists,
		// and doing it before the loop means index.json can still be written when
		// there are no gateways at all.
		await fs.promises.mkdir(output_dir, { recursive: true });
		
		const index = [];
		for(const gateway of this.repo_gateway.iterate()) {
			const filename = path.join(output_dir, `${gateway.id}`);
			this.l.log(`Model output: ${filename}`);
			
			if(!await this.train_gateway(gateway.id, filename)) {
				this.l.warn(`Warning: Failed to train AI for ${gateway.id}.`);
				continue;
			}
			
			index.push({
				id: gateway.id,
				latitude: gateway.latitude,
				longitude: gateway.longitude
			});
		}
		
		await fs.promises.writeFile(
			path.join(output_dir, "index.json"),
			JSON.stringify({
				properties: {
					rssi_min: this.settings.ai.rssi_min,
					rssi_max: this.settings.ai.rssi_max
				},
				index
			})
		);
	}
	
	/**
	 * Trains an AI to predict the coverage of a specific gateway.
	 * @param {string} gateway_id The id of the gateway to train an AI for.
	 * @param {string} destination_filename The absolute path to serialise the trained model to. Required because we can't serialise and return a TensorFlow model; it has to be sent somewhere. tfjs-node's file:// handler writes model.json and weights.bin into this path as a directory.
	 * @return {Promise<boolean>} Resolves to true once training and serialisation are complete, or to false if there is no training data for the gateway. TensorFlow errors are propagated as rejections.
	 */
	async train_gateway(gateway_id, destination_filename) {
		this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
		
		// Awaiting works whether fetch_all is synchronous or returns a promise.
		const datasets = await this.dataset_fetcher.fetch_all(gateway_id);
		if(!datasets || !datasets.input || !datasets.output || datasets.input.length === 0) {
			// Report failure through the boolean contract instead of letting
			// TensorFlow throw a shape error on an empty tensor.
			this.l.warn(`No training data available for gateway ${gateway_id}.`);
			return false;
		}
		
		const model = this.generate_model();
		let inputs = null;
		let outputs = null;
		try {
			inputs = tf.tensor(datasets.input);
			outputs = tf.tensor(datasets.output);
			await model.fit(inputs, outputs, {
				epochs: this.settings.ai.epochs,
				batchSize: this.settings.ai.batch_size,
				validationSplit: this.settings.ai.validation_split
			});
			await model.save(`file://${destination_filename}`);
		}
		finally {
			// Tensor memory in tfjs-node is not garbage-collected, so release it
			// explicitly. Otherwise every gateway trained leaks native/GPU memory.
			if(inputs !== null) inputs.dispose();
			if(outputs !== null) outputs.dispose();
			model.dispose();
		}
		return true;
	}
}
export default AITrainer;