Continue brain.js refactoring, but it's still not finished yet.
This commit is contained in:
parent
2a81ac792e
commit
cce0761fed
3 changed files with 29 additions and 23 deletions
|
@ -65,8 +65,14 @@ network_arch = [ 64 ]
|
||||||
|
|
||||||
# The number of epochs to train for.
|
# The number of epochs to train for.
|
||||||
epochs = 1000
|
epochs = 1000
|
||||||
# The percentage, between 0 and 1, of data that should be set aside for validation.
|
# Cut training short if the mean squared error drops below
|
||||||
validation_split = 0.1
|
# this value
|
||||||
|
error_threshold = 0.0001
|
||||||
|
|
||||||
|
# The learning rate that the neural networks should learn at
|
||||||
|
learning_rate = 0.1
|
||||||
|
# The momentum term to use when learning
|
||||||
|
momentum = 0.1
|
||||||
|
|
||||||
# The directory to output trained AIs to, relative to the repository root.
|
# The directory to output trained AIs to, relative to the repository root.
|
||||||
output_directory = "app/ais/"
|
output_directory = "app/ais/"
|
||||||
|
|
|
@ -70,7 +70,7 @@ class AITrainer {
|
||||||
*/
|
*/
|
||||||
async train_gateway(gateway_id, destination_filename) {
|
async train_gateway(gateway_id, destination_filename) {
|
||||||
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
||||||
let model = this.generate_neural_net();
|
let net = this.generate_neural_net();
|
||||||
|
|
||||||
// let dataset_input = tf.data.generator(
|
// let dataset_input = tf.data.generator(
|
||||||
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||||
|
@ -85,17 +85,17 @@ class AITrainer {
|
||||||
// }).shuffle(this.settings.ai.batch_size * 4)
|
// }).shuffle(this.settings.ai.batch_size * 4)
|
||||||
// .batch(this.settings.ai.batch_size);
|
// .batch(this.settings.ai.batch_size);
|
||||||
//
|
//
|
||||||
let datasets = this.dataset_fetcher.fetch_all(gateway_id);
|
let dataset = this.dataset_fetcher.fetch_all(gateway_id);
|
||||||
|
|
||||||
let result = await model.fit(
|
let result = net.train(dataset, {
|
||||||
tf.tensor(datasets.input),
|
iterations: this.settings.ai.epochs,
|
||||||
tf.tensor(datasets.output),
|
errorThresh: this.settings.ai.error_threshold,
|
||||||
{
|
|
||||||
epochs: this.settings.ai.epochs,
|
learningRate: this.settings.ai.learning_rate,
|
||||||
batchSize: this.settings.ai.batch_size,
|
momentum: this.settings.ai.momentum,
|
||||||
validationSplit: this.settings.ai.validation_split
|
|
||||||
}
|
timeout: Infinity
|
||||||
);
|
});
|
||||||
|
|
||||||
await model.save(`file://${destination_filename}`);
|
await model.save(`file://${destination_filename}`);
|
||||||
// console.log(result);
|
// console.log(result);
|
||||||
|
|
|
@ -28,10 +28,7 @@ class DatasetFetcher {
|
||||||
fetch_all(gateway_id) {
|
fetch_all(gateway_id) {
|
||||||
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
|
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
|
||||||
|
|
||||||
let result = {
|
let result = [];
|
||||||
input: [],
|
|
||||||
output: []
|
|
||||||
};
|
|
||||||
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
||||||
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||||
let distance_from_gateway = haversine(gateway_location, rssi);
|
let distance_from_gateway = haversine(gateway_location, rssi);
|
||||||
|
@ -47,12 +44,15 @@ class DatasetFetcher {
|
||||||
|
|
||||||
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
|
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
|
||||||
|
|
||||||
result.output.push([
|
let next_output = clamp(normalise(rssi.rssi,
|
||||||
clamp(normalise(rssi.rssi,
|
|
||||||
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
), 0, 1)
|
), 0, 1);
|
||||||
]);
|
|
||||||
|
result.push({
|
||||||
|
input: next_input,
|
||||||
|
output: next_output
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for(let reading of this.repo_reading.iterate_unreceived()) {
|
for(let reading of this.repo_reading.iterate_unreceived()) {
|
||||||
|
|
Loading…
Reference in a new issue