2019-07-17 13:16:24 +00:00
|
|
|
"use strict";
|
|
|
|
|
|
|
|
import tf from '@tensorflow/tfjs-node-gpu';
|
|
|
|
|
|
|
|
/**
 * Trains a small dense regression network that maps a (lat, long) pair to a
 * single RSSI value, using per-gateway datasets streamed from a DatasetFetcher.
 *
 * NOTE(review): a single model instance is created in the constructor and
 * shared by every gateway, so train_all() keeps updating the same weights
 * gateway after gateway. Confirm whether one model per gateway was intended.
 */
class AITrainer {
	/**
	 * @param {Object} deps
	 * @param {Object} deps.settings - Configuration object; `settings.ai.batch_size`
	 *   and `settings.ai.epochs` are read when training.
	 * @param {Object} deps.log - Logger exposing a `log(message)` method.
	 * @param {Object} deps.GatewayRepo - Repository whose `iterate()` returns an
	 *   iterable of gateway records carrying an `id` property.
	 * @param {Object} deps.DatasetFetcher - Provides `fetch_input(gateway_id)` and
	 *   `fetch_output(gateway_id)`, used as generator functions for `tf.data.generator`.
	 */
	constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
		this.settings = settings;
		this.l = log;
		this.dataset_fetcher = DatasetFetcher;
		this.repo_gateway = GatewayRepo;

		// Built and compiled eagerly; this also logs the model summary.
		this.model = this.generate_model();
	}

	/**
	 * Builds and compiles the network: 2 inputs -> 256 sigmoid units -> 1 sigmoid output,
	 * trained with Adam on absolute-difference loss, tracking mean squared error.
	 * The model summary is written through the injected logger.
	 *
	 * @returns {tf.Sequential} The compiled model.
	 */
	generate_model() {
		const model = tf.sequential();

		model.add(tf.layers.dense({
			units: 256, // 256 nodes
			activation: "sigmoid", // Sigmoid activation function
			inputShape: [2], // 2 inputs - lat and long
		}));

		model.add(tf.layers.dense({
			units: 1, // 1 output value - RSSI
			// The example code uses softmax, but this is generally best used for classification tasks.
			// NOTE(review): sigmoid limits predictions to (0, 1); this assumes the RSSI
			// targets produced by DatasetFetcher are normalised into that range — confirm.
			activation: "sigmoid",
		}));

		model.compile({
			optimizer: tf.train.adam(),
			loss: tf.losses.absoluteDifference,
			metrics: [ tf.metrics.meanSquaredError ],
		});

		this.l.log(`Model:`);
		// Send the summary through the same logger as the header above instead of
		// the default printFn (console.log).
		model.summary(undefined, undefined, (line) => this.l.log(line));

		return model;
	}

	/**
	 * Trains the shared model on every gateway yielded by the repository.
	 * Gateways are processed sequentially: each training run is awaited before the next begins.
	 */
	async train_all() {
		for (const gateway of this.repo_gateway.iterate()) {
			await this.train_gateway(gateway.id);
		}
	}

	/**
	 * Streams the input/output pairs for one gateway, shuffles and batches them,
	 * and fits the model on the resulting dataset.
	 *
	 * @param {*} gateway_id - Identifier forwarded to the DatasetFetcher methods.
	 * @returns {Promise<tf.History>} The history object returned by `fitDataset`.
	 */
	async train_gateway(gateway_id) {
		const { batch_size, epochs } = this.settings.ai;

		// bind() keeps `this` pointing at the fetcher and pre-fills the gateway id,
		// giving tf.data.generator the zero-argument generator function it expects.
		const dataset_input = tf.data.generator(
			this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
		);
		const dataset_output = tf.data.generator(
			this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
		);

		// Zip before shuffling so every (xs, ys) pair stays aligned.
		// NOTE(review): assumes fetch_input and fetch_output yield matching items in the same order.
		const dataset = tf.data.zip({ xs: dataset_input, ys: dataset_output })
			.shuffle(batch_size * 4)
			.batch(batch_size);

		// Batching is handled by dataset.batch(); fitDataset takes no batchSize option.
		const result = await this.model.fitDataset(dataset, { epochs });

		this.l.log(`Training history for gateway ${gateway_id}: ${JSON.stringify(result.history)}`);

		return result;
	}
}

export default AITrainer;
|