LoRaWAN-Signal-Mapping/server/train-ai/AITrainer.mjs

68 lines
1.7 KiB
JavaScript
Raw Normal View History

"use strict";
import tf from '@tensorflow/tfjs-node-gpu';
/**
 * Trains a small feed-forward TensorFlow.js regression model on data
 * supplied per gateway by an injected DatasetFetcher.
 */
class AITrainer {
	/**
	 * Stores the injected dependencies and builds + compiles the model
	 * immediately (which logs a model summary).
	 * @param {object} deps
	 * @param {object} deps.settings - Read for `ai.batch_size` and `ai.epochs`.
	 * @param {object} deps.log - Logger; its `log(message)` method is called with strings.
	 * @param {object} deps.GatewayRepo - Must expose `iterate()`, yielding objects with an `id`.
	 * @param {object} deps.DatasetFetcher - Must expose `fetch_input(gateway_id)` and
	 *   `fetch_output(gateway_id)`; each is used as a `tf.data.generator` source.
	 */
	constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
		this.settings = settings;
		this.l = log;
		this.dataset_fetcher = DatasetFetcher;
		this.repo_gateway = GatewayRepo;
		// settings and logger must be assigned first: generate_model() uses this.l.
		this.model = this.generate_model();
	}

	/**
	 * Builds and compiles a 2-input, 1-output sequential model with one
	 * 256-unit hidden layer.
	 * @returns {tf.Sequential} The compiled model.
	 */
	generate_model() {
		const model = tf.sequential();
		model.add(tf.layers.dense({
			units: 256, // 256 nodes
			activation: "sigmoid", // Sigmoid activation function
			inputShape: [2], // 2 inputs - lat and long
		}));
		// NOTE(review): sigmoid limits the prediction to (0, 1), so the RSSI
		// targets yielded by fetch_output presumably need to be normalised into
		// that range - confirm against DatasetFetcher.
		model.add(tf.layers.dense({
			units: 1, // 1 output value - RSSI
			activation: "sigmoid", // The example code uses softmax, but this is generally best used for classification tasks
		}));
		model.compile({
			optimizer: tf.train.adam(),
			loss: tf.losses.absoluteDifference,
			metrics: [ tf.metrics.meanSquaredError ],
		});
		this.l.log(`Model:`);
		model.summary();
		return model;
	}

	/**
	 * Trains on every gateway yielded by the gateway repository, one after
	 * another (each training run is awaited before the next starts).
	 */
	async train_all() {
		for (const gateway of this.repo_gateway.iterate()) {
			await this.train_gateway(gateway.id);
		}
	}

	/**
	 * Trains the model on the data for a single gateway.
	 *
	 * NOTE(review): this.model is created once in the constructor and reused,
	 * so each call continues training the same weights left by the previous
	 * gateway, and nothing here persists a per-gateway model. Confirm whether
	 * one model per gateway is intended.
	 *
	 * @param {*} gateway_id - Passed straight through to the DatasetFetcher.
	 * @returns {Promise<tf.History>} The training history from fitDataset.
	 */
	async train_gateway(gateway_id) {
		const { batch_size, epochs } = this.settings.ai;

		const dataset_input = tf.data.generator(
			() => this.dataset_fetcher.fetch_input(gateway_id)
		);
		const dataset_output = tf.data.generator(
			() => this.dataset_fetcher.fetch_output(gateway_id)
		);

		// Zip before shuffling so each input stays paired with its output.
		const dataset = tf.data.zip({
			xs: dataset_input,
			ys: dataset_output,
		}).shuffle(batch_size * 4)
			.batch(batch_size);

		// Batch size is determined by .batch() above; fitDataset() has no
		// batchSize option, so only the epoch count is passed here.
		const result = await this.model.fitDataset(dataset, { epochs });

		this.l.log(`Gateway ${gateway_id} training history: ${JSON.stringify(result.history)}`);
		return result;
	}
}
export default AITrainer;