Continue brain.js refactoring, but it's still not finished yet.
This commit is contained in:
parent
2a81ac792e
commit
cce0761fed
3 changed files with 29 additions and 23 deletions
|
@ -65,8 +65,14 @@ network_arch = [ 64 ]
|
|||
|
||||
# The number of epochs to train for.
|
||||
epochs = 1000
|
||||
# The percentage, between 0 and 1, of data that should be set aside for validation.
|
||||
validation_split = 0.1
|
||||
# Cut training short if the mean squared error drops below
|
||||
# this value
|
||||
error_threshold = 0.0001
|
||||
|
||||
# The learning rate that the neural networks should learn at
|
||||
learning_rate = 0.1
|
||||
# The momentum term to use when learning
|
||||
momentum = 0.1
|
||||
|
||||
# The directory to output trained AIs to, relative to the repository root.
|
||||
output_directory = "app/ais/"
|
||||
|
|
|
@ -70,7 +70,7 @@ class AITrainer {
|
|||
*/
|
||||
async train_gateway(gateway_id, destination_filename) {
|
||||
this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`);
|
||||
let model = this.generate_neural_net();
|
||||
let net = this.generate_neural_net();
|
||||
|
||||
// let dataset_input = tf.data.generator(
|
||||
// this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||
|
@ -85,17 +85,17 @@ class AITrainer {
|
|||
// }).shuffle(this.settings.ai.batch_size * 4)
|
||||
// .batch(this.settings.ai.batch_size);
|
||||
//
|
||||
let datasets = this.dataset_fetcher.fetch_all(gateway_id);
|
||||
let dataset = this.dataset_fetcher.fetch_all(gateway_id);
|
||||
|
||||
let result = await model.fit(
|
||||
tf.tensor(datasets.input),
|
||||
tf.tensor(datasets.output),
|
||||
{
|
||||
epochs: this.settings.ai.epochs,
|
||||
batchSize: this.settings.ai.batch_size,
|
||||
validationSplit: this.settings.ai.validation_split
|
||||
}
|
||||
);
|
||||
let result = net.train(dataset, {
|
||||
iterations: this.settings.ai.epochs,
|
||||
errorThresh: this.settings.ai.error_threshold,
|
||||
|
||||
learningRate: this.settings.ai.learning_rate,
|
||||
momentum: this.settings.ai.momentum,
|
||||
|
||||
timeout: Infinity
|
||||
});
|
||||
|
||||
await model.save(`file://${destination_filename}`);
|
||||
// console.log(result);
|
||||
|
|
|
@ -28,10 +28,7 @@ class DatasetFetcher {
|
|||
fetch_all(gateway_id) {
|
||||
let gateway_location = this.repo_gateway.get_by_id(gateway_id);
|
||||
|
||||
let result = {
|
||||
input: [],
|
||||
output: []
|
||||
};
|
||||
let result = [];
|
||||
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
||||
let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude);
|
||||
let distance_from_gateway = haversine(gateway_location, rssi);
|
||||
|
@ -47,12 +44,15 @@ class DatasetFetcher {
|
|||
|
||||
console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`);
|
||||
|
||||
result.output.push([
|
||||
clamp(normalise(rssi.rssi,
|
||||
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||
{ min: 0, max: 1 }
|
||||
), 0, 1)
|
||||
]);
|
||||
let next_output = clamp(normalise(rssi.rssi,
|
||||
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||
{ min: 0, max: 1 }
|
||||
), 0, 1);
|
||||
|
||||
result.push({
|
||||
input: next_input,
|
||||
output: next_output
|
||||
});
|
||||
}
|
||||
|
||||
for(let reading of this.repo_reading.iterate_unreceived()) {
|
||||
|
|
Loading…
Reference in a new issue