Finish refactoring to use Brain.js, but it's untested.
This commit is contained in:
parent
cce0761fed
commit
03c1cbb97f
3 changed files with 45 additions and 39 deletions
|
@ -2,13 +2,11 @@
|
||||||
|
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
|
|
||||||
import {
|
import brain from 'brain.js';
|
||||||
loadLayersModel as tf_loadLayersModel,
|
import haversine from 'haversine-distance';
|
||||||
tensor as tf_tensor,
|
|
||||||
setBackend as tf_setBackend
|
|
||||||
} from '@tensorflow/tfjs';
|
|
||||||
|
|
||||||
import { normalise } from '../../../common/Math.mjs';
|
import { normalise, clamp } from '../../../common/Math.mjs';
|
||||||
|
import GetFromUrl from '../Helpers/GetFromUrl.mjs';
|
||||||
|
|
||||||
class AIWrapper {
|
class AIWrapper {
|
||||||
constructor() {
|
constructor() {
|
||||||
|
@ -29,12 +27,16 @@ class AIWrapper {
|
||||||
console.log("Loading models");
|
console.log("Loading models");
|
||||||
|
|
||||||
// WebGL isn't available inside WebWorkers yet :-(
|
// WebGL isn't available inside WebWorkers yet :-(
|
||||||
tf_setBackend("cpu");
|
|
||||||
|
|
||||||
for(let gateway of this.index.index) {
|
for(let gateway of this.index.index) {
|
||||||
|
let net = new brain.NeuralNetwork();
|
||||||
|
net.fromJSON(
|
||||||
|
await GetFromUrl(`${path.dirname(self.location.href)}/${path.dirname(this.Config.ai_index_file)}/${gateway.filename}`)
|
||||||
|
);
|
||||||
|
|
||||||
this.gateways.set(
|
this.gateways.set(
|
||||||
gateway.id,
|
gateway.id,
|
||||||
await tf_loadLayersModel(`${path.dirname(self.location.href)}/${path.dirname(this.Config.ai_index_file)}/${gateway.id}/model.json`)
|
net
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
console.log("Model setup complete.");
|
console.log("Model setup complete.");
|
||||||
|
@ -54,13 +56,29 @@ class AIWrapper {
|
||||||
for(let lng = this.map_bounds.west; lng < this.map_bounds.east; lng += this.Config.step.lng) {
|
for(let lng = this.map_bounds.west; lng < this.map_bounds.east; lng += this.Config.step.lng) {
|
||||||
let max_predicted_rssi = -Infinity;
|
let max_predicted_rssi = -Infinity;
|
||||||
|
|
||||||
for(let [, ai] of this.gateways) {
|
for(let [gateway_id, ai] of this.gateways) {
|
||||||
let next_prediction = ai.predict(
|
let distance_from_gateway = haversine(
|
||||||
tf_tensor([ lat, lng ], [1, 2])
|
{ latitude: lat, longitude: lng },
|
||||||
).arraySync()[0][0];
|
this.gateways.get(gateway_id)
|
||||||
|
);
|
||||||
max_predicted_rssi = Math.max(
|
max_predicted_rssi = Math.max(
|
||||||
max_predicted_rssi,
|
max_predicted_rssi,
|
||||||
next_prediction
|
ai.run({
|
||||||
|
latitude: normalise(lat,
|
||||||
|
{ min: -90, max: +90 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
),
|
||||||
|
longitude: normalise(lng,
|
||||||
|
{ min: -180, max: +180 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
),
|
||||||
|
distance: clamp(
|
||||||
|
normalise(distance_from_gateway,
|
||||||
|
{ min: 0, max: 20000 },
|
||||||
|
{ min: 0, max: 1 }
|
||||||
|
),
|
||||||
|
0, 1)
|
||||||
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,19 +28,20 @@ class AITrainer {
|
||||||
async train_all() {
|
async train_all() {
|
||||||
let index = [];
|
let index = [];
|
||||||
for(let gateway of this.repo_gateway.iterate()) {
|
for(let gateway of this.repo_gateway.iterate()) {
|
||||||
let filename = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}`);
|
let filepath = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}.json`);
|
||||||
console.log(filename);
|
console.log(filepath);
|
||||||
|
|
||||||
if(!fs.existsSync(path.dirname(filename)))
|
if(!fs.existsSync(path.dirname(filepath)))
|
||||||
await fs.promises.mkdir(path.dirname(filename), { recursive: true });
|
await fs.promises.mkdir(path.dirname(filepath), { recursive: true });
|
||||||
|
|
||||||
if(!await this.train_gateway(gateway.id, filename)) {
|
if(!await this.train_gateway(gateway.id, filepath)) {
|
||||||
this.l.warn(`Warning: Failed to train AI for ${gateway.id}.`);
|
this.l.warn(`Warning: Failed to train AI for ${gateway.id}.`);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
index.push({
|
index.push({
|
||||||
id: gateway.id,
|
id: gateway.id,
|
||||||
|
filename: path.basename(filepath),
|
||||||
latitude: gateway.latitude,
|
latitude: gateway.latitude,
|
||||||
longitude: gateway.longitude
|
longitude: gateway.longitude
|
||||||
});
|
});
|
||||||
|
@ -72,22 +73,9 @@ class AITrainer {
|
||||||
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
||||||
let net = this.generate_neural_net();
|
let net = this.generate_neural_net();
|
||||||
|
|
||||||
// let dataset_input = tf.data.generator(
|
|
||||||
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
|
||||||
// );
|
|
||||||
// let dataset_output = tf.data.generator(
|
|
||||||
// this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// let dataset = tf.data.zip({
|
|
||||||
// xs: dataset_input,
|
|
||||||
// ys: dataset_output
|
|
||||||
// }).shuffle(this.settings.ai.batch_size * 4)
|
|
||||||
// .batch(this.settings.ai.batch_size);
|
|
||||||
//
|
|
||||||
let dataset = this.dataset_fetcher.fetch_all(gateway_id);
|
let dataset = this.dataset_fetcher.fetch_all(gateway_id);
|
||||||
|
|
||||||
let result = net.train(dataset, {
|
await net.trainAsync(dataset, {
|
||||||
iterations: this.settings.ai.epochs,
|
iterations: this.settings.ai.epochs,
|
||||||
errorThresh: this.settings.ai.error_threshold,
|
errorThresh: this.settings.ai.error_threshold,
|
||||||
|
|
||||||
|
@ -97,7 +85,7 @@ class AITrainer {
|
||||||
timeout: Infinity
|
timeout: Infinity
|
||||||
});
|
});
|
||||||
|
|
||||||
await model.save(`file://${destination_filename}`);
|
await fs.promises.writeFile(destination_filename, net.toJSON());
|
||||||
// console.log(result);
|
// console.log(result);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -13,16 +13,16 @@ class DatasetFetcher {
|
||||||
}
|
}
|
||||||
|
|
||||||
normalise_latlng(lat, lng) {
|
normalise_latlng(lat, lng) {
|
||||||
return [
|
return {
|
||||||
normalise(lat,
|
latitude: normalise(lat,
|
||||||
{ min: -90, max: +90 },
|
{ min: -90, max: +90 },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
),
|
),
|
||||||
normalise(lng,
|
longitude: normalise(lng,
|
||||||
{ min: -180, max: +180 },
|
{ min: -180, max: +180 },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
)
|
)
|
||||||
];
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch_all(gateway_id) {
|
fetch_all(gateway_id) {
|
||||||
|
@ -33,12 +33,12 @@ class DatasetFetcher {
|
||||||
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||||
let distance_from_gateway = haversine(gateway_location, rssi);
|
let distance_from_gateway = haversine(gateway_location, rssi);
|
||||||
|
|
||||||
next_input.push(clamp(
|
next_input.distance = clamp(
|
||||||
normalise(distance_from_gateway,
|
normalise(distance_from_gateway,
|
||||||
{ min: 0, max: 20000 },
|
{ min: 0, max: 20000 },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
),
|
),
|
||||||
0, 1))
|
0, 1);
|
||||||
|
|
||||||
result.input.push(next_input);
|
result.input.push(next_input);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue