Continue the brain.js refactoring; it is still not finished yet.

This commit is contained in:
Starbeamrainbowlabs 2019-07-29 17:02:34 +01:00
parent 2a81ac792e
commit cce0761fed
3 changed files with 29 additions and 23 deletions

View file

@ -65,8 +65,14 @@ network_arch = [ 64 ]
# The number of epochs to train for. # The number of epochs to train for.
epochs = 1000 epochs = 1000
# The percentage, between 0 and 1, of data that should be set aside for validation. # Cut training short if the mean squared error drops below
validation_split = 0.1 # this value
error_threshold = 0.0001
# The learning rate that the neural networks should learn at
learning_rate = 0.1
# The momentum term to use when learning
momentum = 0.1
# The directory to output trained AIs to, relative to the repository root. # The directory to output trained AIs to, relative to the repository root.
output_directory = "app/ais/" output_directory = "app/ais/"

View file

@ -70,7 +70,7 @@ class AITrainer {
*/ */
async train_gateway(gateway_id, destination_filename) { async train_gateway(gateway_id, destination_filename) {
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`); this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
let model = this.generate_neural_net(); let net = this.generate_neural_net();
// let dataset_input = tf.data.generator( // let dataset_input = tf.data.generator(
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id) // this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
@ -85,17 +85,17 @@ class AITrainer {
// }).shuffle(this.settings.ai.batch_size * 4) // }).shuffle(this.settings.ai.batch_size * 4)
// .batch(this.settings.ai.batch_size); // .batch(this.settings.ai.batch_size);
// //
let datasets = this.dataset_fetcher.fetch_all(gateway_id); let dataset = this.dataset_fetcher.fetch_all(gateway_id);
let result = await model.fit( let result = net.train(dataset, {
tf.tensor(datasets.input), iterations: this.settings.ai.epochs,
tf.tensor(datasets.output), errorThresh: this.settings.ai.error_threshold,
{
epochs: this.settings.ai.epochs, learningRate: this.settings.ai.learning_rate,
batchSize: this.settings.ai.batch_size, momentum: this.settings.ai.momentum,
validationSplit: this.settings.ai.validation_split
} timeout: Infinity
); });
await model.save(`file://${destination_filename}`); await model.save(`file://${destination_filename}`);
// console.log(result); // console.log(result);

View file

@ -28,10 +28,7 @@ class DatasetFetcher {
fetch_all(gateway_id) { fetch_all(gateway_id) {
let gateway_location = this.repo_gateway.get_by_id(gateway_id); let gateway_location = this.repo_gateway.get_by_id(gateway_id);
let result = { let result = [];
input: [],
output: []
};
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) { for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude); let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
let distance_from_gateway = haversine(gateway_location, rssi); let distance_from_gateway = haversine(gateway_location, rssi);
@ -47,12 +44,15 @@ class DatasetFetcher {
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`); console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
result.output.push([ let next_output = clamp(normalise(rssi.rssi,
clamp(normalise(rssi.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 } { min: 0, max: 1 }
), 0, 1) ), 0, 1);
]);
result.push({
input: next_input,
output: next_output
});
} }
for(let reading of this.repo_reading.iterate_unreceived()) { for(let reading of this.repo_reading.iterate_unreceived()) {