Finish refactoring to use Brain.js, but it's untested.

Starbeamrainbowlabs 2019-07-29 18:06:50 +01:00
parent cce0761fed
commit 03c1cbb97f
3 changed files with 45 additions and 39 deletions

View file: class AIWrapper

@@ -2,13 +2,11 @@
 import path from 'path';
-import {
-	loadLayersModel as tf_loadLayersModel,
-	tensor as tf_tensor,
-	setBackend as tf_setBackend
-} from '@tensorflow/tfjs';
-import { normalise } from '../../../common/Math.mjs';
+import brain from 'brain.js';
+import haversine from 'haversine-distance';
+import { normalise, clamp } from '../../../common/Math.mjs';
+import GetFromUrl from '../Helpers/GetFromUrl.mjs';
 
 class AIWrapper {
 	constructor() {
@@ -29,12 +27,16 @@ class AIWrapper {
 		console.log("Loading models");
 		// WebGL isn't available inside WebWorkers yet :-(
-		tf_setBackend("cpu");
 		for(let gateway of this.index.index) {
+			let net = new brain.NeuralNetwork();
+			net.fromJSON(
+				await GetFromUrl(`${path.dirname(self.location.href)}/${path.dirname(this.Config.ai_index_file)}/${gateway.filename}`)
+			);
 			this.gateways.set(
 				gateway.id,
-				await tf_loadLayersModel(`${path.dirname(self.location.href)}/${path.dirname(this.Config.ai_index_file)}/${gateway.id}/model.json`)
+				net
 			);
 		}
 		console.log("Model setup complete.");
@@ -54,13 +56,29 @@ class AIWrapper {
 		for(let lng = this.map_bounds.west; lng < this.map_bounds.east; lng += this.Config.step.lng) {
 			let max_predicted_rssi = -Infinity;
-			for(let [, ai] of this.gateways) {
-				let next_prediction = ai.predict(
-					tf_tensor([ lat, lng ], [1, 2])
-				).arraySync()[0][0];
+			for(let [gateway_id, ai] of this.gateways) {
+				let distance_from_gateway = haversine(
+					{ latitude: lat, longitude: lng },
+					this.gateways.get(gateway_id)
+				);
 				max_predicted_rssi = Math.max(
 					max_predicted_rssi,
-					next_prediction
+					ai.run({
+						latitude: normalise(lat,
+							{ min: -90, max: +90 },
+							{ min: 0, max: 1 }
+						),
+						longitude: normalise(lng,
+							{ min: -180, max: +180 },
+							{ min: 0, max: 1 }
+						),
+						distance: clamp(
+							normalise(distance_from_gateway,
+								{ min: 0, max: 20000 },
+								{ min: 0, max: 1 }
+							),
+						0, 1)
+					})
 				);
 			}
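One caveat in the hunk above: this.gateways now maps gateway id to the trained network, so this.gateways.get(gateway_id) hands haversine a brain.NeuralNetwork instance rather than a coordinate pair (the commit message does flag all of this as untested). A hypothetical fix, assuming the index entries carry the latitude/longitude written out by AITrainer below:

	// Look the gateway's location up in the index, not in the id-to-network map.
	let gateway_info = this.index.index.find(gateway => gateway.id === gateway_id);
	let distance_from_gateway = haversine(
		{ latitude: lat, longitude: lng },
		{ latitude: gateway_info.latitude, longitude: gateway_info.longitude }
	);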

View file: class AITrainer

@@ -28,19 +28,20 @@ class AITrainer {
 	async train_all() {
 		let index = [];
 		for(let gateway of this.repo_gateway.iterate()) {
-			let filename = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}`);
-			console.log(filename);
-			if(!fs.existsSync(path.dirname(filename)))
-				await fs.promises.mkdir(path.dirname(filename), { recursive: true });
-			if(!await this.train_gateway(gateway.id, filename)) {
+			let filepath = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}.json`);
+			console.log(filepath);
+			if(!fs.existsSync(path.dirname(filepath)))
+				await fs.promises.mkdir(path.dirname(filepath), { recursive: true });
+			if(!await this.train_gateway(gateway.id, filepath)) {
 				this.l.warn(`Warning: Failed to train AI for ${gateway.id}.`);
 				continue;
 			}
 			index.push({
 				id: gateway.id,
+				filename: path.basename(filepath),
 				latitude: gateway.latitude,
 				longitude: gateway.longitude
 			});
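Judging by the this.index.index lookups on the client side, the index file this loop feeds would come out looking something like the following (values illustrative):

	{
		"index": [
			{
				"id": "gateway-a",
				"filename": "gateway-a.json",
				"latitude": 53.76,
				"longitude": -0.43
			}
		]
	}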
@@ -72,22 +73,9 @@ class AITrainer {
 		this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
 		let net = this.generate_neural_net();
-		// let dataset_input = tf.data.generator(
-		// 	this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
-		// );
-		// let dataset_output = tf.data.generator(
-		// 	this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
-		// );
-		//
-		// let dataset = tf.data.zip({
-		// 	xs: dataset_input,
-		// 	ys: dataset_output
-		// }).shuffle(this.settings.ai.batch_size * 4)
-		// 	.batch(this.settings.ai.batch_size);
-		//
 		let dataset = this.dataset_fetcher.fetch_all(gateway_id);
-		let result = net.train(dataset, {
+		await net.trainAsync(dataset, {
 			iterations: this.settings.ai.epochs,
 			errorThresh: this.settings.ai.error_threshold,
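brain.js's trainAsync() returns a promise that resolves to a training-stats object, so the result the old net.train() call captured is still recoverable if wanted. A sketch with made-up option values:

	// trainAsync() resolves once errorThresh is reached or iterations/timeout run out.
	let result = await net.trainAsync(dataset, {
		iterations: 2000,    // upper bound on training iterations
		errorThresh: 0.005,  // stop early once the mean error drops below this
		timeout: Infinity    // no wall-clock limit
	});
	console.log(`error: ${result.error}, iterations: ${result.iterations}`);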
@@ -97,7 +85,7 @@ class AITrainer {
 			timeout: Infinity
 		});
-		await model.save(`file://${destination_filename}`);
+		await fs.promises.writeFile(destination_filename, net.toJSON());
 		// console.log(result);
 		return true;
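Worth noting: net.toJSON() returns a plain object rather than the string fs.promises.writeFile expects, so a JSON.stringify() is presumably still needed here (again, the commit is flagged untested):

	// Hypothetical correction: stringify the exported network before writing it out.
	await fs.promises.writeFile(
		destination_filename,
		JSON.stringify(net.toJSON())
	);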

View file: class DatasetFetcher

@@ -13,16 +13,16 @@ class DatasetFetcher {
 	}
 	
 	normalise_latlng(lat, lng) {
-		return [
-			normalise(lat,
+		return {
+			latitude: normalise(lat,
 				{ min: -90, max: +90 },
 				{ min: 0, max: 1 }
 			),
-			normalise(lng,
+			longitude: normalise(lng,
 				{ min: -180, max: +180 },
 				{ min: 0, max: 1 }
 			)
-		];
+		};
 	}
 	
 	fetch_all(gateway_id) {
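normalise() and clamp() come from common/Math.mjs, which this commit doesn't touch; from their call sites they presumably behave like the following linear remap and range clamp (a sketch, not the project's actual source):

	// Linearly remap value from the range `from` onto the range `to`.
	function normalise(value, from, to) {
		return ((value - from.min) / (from.max - from.min)) * (to.max - to.min) + to.min;
	}

	// Constrain value to the closed interval [min, max].
	function clamp(value, min, max) {
		return Math.min(Math.max(value, min), max);
	}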
@@ -33,12 +33,12 @@ class DatasetFetcher {
 			let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
 			let distance_from_gateway = haversine(gateway_location, rssi);
-			next_input.push(clamp(
-				normalise(distance_from_gateway,
-					{ min: 0, max: 20000 },
-					{ min: 0, max: 1 }
-				),
-			0, 1))
+			next_input.distance = clamp(
+				normalise(distance_from_gateway,
+					{ min: 0, max: 20000 },
+					{ min: 0, max: 1 }
+				),
+			0, 1);
 			result.input.push(next_input);
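With inputs now keyed by name, each sample that eventually reaches net.trainAsync() needs to match the array-of-{input, output} shape brain.js expects; assuming fetch_all() zips its input and output lists into that form, a single sample would look something like this (values illustrative, and the output key isn't visible in this diff):

	{
		input:  { latitude: 0.62, longitude: 0.48, distance: 0.05 },
		output: { rssi: 0.73 }
	}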