Continue brain.js refactoring, but it's still not finished yet.

This commit is contained in:
Starbeamrainbowlabs 2019-07-29 17:02:34 +01:00
parent 2a81ac792e
commit cce0761fed
3 changed files with 29 additions and 23 deletions

View file

@ -65,8 +65,14 @@ network_arch = [ 64 ]
# The number of epochs to train for.
epochs = 1000
# The percentage, between 0 and 1, of data that should be set aside for validation.
validation_split = 0.1
# Cut training short if the mean squared error drops below
# this value
error_threshold = 0.0001
# The learning rate that the neural networks should learn at
learning_rate = 0.1
# The momentum term to use when learning
momentum = 0.1
# The directory to output trained AIs to, relative to the repository root.
output_directory = "app/ais/"

View file

@ -70,7 +70,7 @@ class AITrainer {
*/
async train_gateway(gateway_id, destination_filename) {
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
let model = this.generate_neural_net();
let net = this.generate_neural_net();
// let dataset_input = tf.data.generator(
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
@ -85,17 +85,17 @@ class AITrainer {
// }).shuffle(this.settings.ai.batch_size * 4)
// .batch(this.settings.ai.batch_size);
//
let datasets = this.dataset_fetcher.fetch_all(gateway_id);
let dataset = this.dataset_fetcher.fetch_all(gateway_id);
let result = await model.fit(
tf.tensor(datasets.input),
tf.tensor(datasets.output),
{
epochs: this.settings.ai.epochs,
batchSize: this.settings.ai.batch_size,
validationSplit: this.settings.ai.validation_split
}
);
let result = net.train(dataset, {
iterations: this.settings.ai.epochs,
errorThresh: this.settings.ai.error_threshold,
learningRate: this.settings.ai.learning_rate,
momentum: this.settings.ai.momentum,
timeout: Infinity
});
await model.save(`file://${destination_filename}`);
// console.log(result);

View file

@ -28,10 +28,7 @@ class DatasetFetcher {
fetch_all(gateway_id) {
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
let result = {
input: [],
output: []
};
let result = [];
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
let distance_from_gateway = haversine(gateway_location, rssi);
@ -47,12 +44,15 @@ class DatasetFetcher {
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
result.output.push([
clamp(normalise(rssi.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 }
), 0, 1)
]);
let next_output = clamp(normalise(rssi.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 }
), 0, 1);
result.push({
input: next_input,
output: next_output
});
}
for(let reading of this.repo_reading.iterate_unreceived()) {