@@ -6,7 +6,8 @@ import fs from 'fs';
 import tf from '@tensorflow/tfjs-node-gpu';
 
 class AITrainer {
-	constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
+	constructor({ ansi, settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
+		this.a = ansi;
 		this.settings = settings;
 		this.root_dir = root_dir;
 		this.l = log;
@@ -17,9 +18,9 @@ class AITrainer {
 	generate_model() {
 		let model = tf.sequential();
 		model.add(tf.layers.dense({
-			units: 256, // 256 nodes
+			units: 64, // 64 nodes
 			activation: "sigmoid", // Sigmoid activation function
-			inputShape: [2], // 2 inputs - lat and long
+			inputShape: [3], // 3 inputs - lat, long, and distance from gateway
 		}))
 		model.add(tf.layers.dense({
 			units: 1, // 1 output value - RSSI
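Pieced together, the model this hunk converges on is sketched below. Only the first dense layer and the opening of the second appear in the diff, so the output layer's activation and the compile step are assumptions made to keep the example runnable, not something the diff confirms:

import tf from '@tensorflow/tfjs-node-gpu';

function generate_model() {
	let model = tf.sequential();
	// Hidden layer: 64 sigmoid nodes over [ latitude, longitude, distance from gateway ]
	model.add(tf.layers.dense({
		units: 64,
		activation: "sigmoid",
		inputShape: [3]
	}));
	// Output layer: a single predicted RSSI value
	model.add(tf.layers.dense({
		units: 1,
		activation: "sigmoid" // assumed - the diff cuts off before this line
	}));
	// Assumed training configuration - not shown in the diff
	model.compile({
		optimizer: tf.train.adam(),
		loss: "meanSquaredError"
	});
	return model;
}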
@@ -82,30 +83,36 @@ class AITrainer {
	 * @return {Promise} A promise that resolves when training and serialisation is complete.
	 */
	async train_gateway(gateway_id, destination_filename) {
		this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
		let model = this.generate_model();
		
		// TODO: Add samples here for locations that the gateway does NOT cover too
-		let dataset_input = tf.data.generator(
-			this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
-		);
-		let dataset_output = tf.data.generator(
-			this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
-		);
-		
-		let dataset = tf.data.zip({
-			xs: dataset_input,
-			ys: dataset_output
-		}).shuffle(this.settings.ai.batch_size * 4)
-			.batch(this.settings.ai.batch_size);
-		
+		// let dataset_input = tf.data.generator(
+		// 	this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
+		// );
+		// let dataset_output = tf.data.generator(
+		// 	this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
+		// );
+		// 
+		// let dataset = tf.data.zip({
+		// 	xs: dataset_input,
+		// 	ys: dataset_output
+		// }).shuffle(this.settings.ai.batch_size * 4)
+		// 	.batch(this.settings.ai.batch_size);
+		// 
+		let datasets = this.dataset_fetcher.fetch_all(gateway_id);
-		let result = await model.fitDataset(dataset, {
-			epochs: this.settings.ai.epochs,
-			batchSize: this.settings.ai.batch_size
-		});
+		let result = await model.fit(
+			tf.tensor(datasets.input),
+			tf.tensor(datasets.output),
+			{
+				epochs: this.settings.ai.epochs,
+				batchSize: this.settings.ai.batch_size,
+				validationSplit: this.settings.ai.validation_split
+			}
+		);
		
		await model.save(`file://${destination_filename}`);
-		console.log(result);
+		// console.log(result);
		
		return true;
	}
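The move from fitDataset() to fit() implies that DatasetFetcher's fetch_all() now materialises the entire dataset in memory as plain nested arrays, rather than streaming it through tf.data generators. A minimal sketch of the shape fit() expects here, assuming fetch_all() returns an { input, output } pair with one row per sample (the sample values are invented for illustration):

import tf from '@tensorflow/tfjs-node-gpu';

// Assumed return shape of dataset_fetcher.fetch_all(gateway_id)
let datasets = {
	input: [ // [ latitude, longitude, distance from gateway ]
		[ 53.067, -1.234, 1029.2 ],
		[ 53.071, -1.229, 784.5 ]
	],
	output: [ // [ RSSI ] - one output value per sample
		[ -110 ],
		[ -98 ]
	]
};

// tf.tensor() infers a [ samples, features ] shape from nested arrays,
// which lines up with the model's inputShape: [3] and single-unit output.
let xs = tf.tensor(datasets.input);  // shape: [2, 3]
let ys = tf.tensor(datasets.output); // shape: [2, 1]
console.log(xs.shape, ys.shape);

One side effect of this design is that the whole dataset has to fit in memory at once; the upside is that fit() accepts the validationSplit option, which fitDataset() does not support (fitDataset() takes explicit validationData instead).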