Write the the data input, but it's untested.
There's also some confusion over the batch size. If it doesn't work, we probably need to revise our understanding of the batch size in different contexts.
This commit is contained in:
parent
6dbb6c3b87
commit
192e31a925
3 changed files with 52 additions and 27 deletions
|
@ -59,6 +59,14 @@ devices = [
|
||||||
rssi_min = -150
|
rssi_min = -150
|
||||||
rssi_max = 0
|
rssi_max = 0
|
||||||
|
|
||||||
|
# Data is streamed from the SQLite databse. The batch size specifies how many
|
||||||
|
# rows to gather for training at once.
|
||||||
|
# Note that the entire available dataset will eventually end up being used -
|
||||||
|
# this setting just controls how much of ti is in memory at once.
|
||||||
|
batch_size = 32
|
||||||
|
# The number of epochs to train for.
|
||||||
|
epochs = 5
|
||||||
|
|
||||||
[logging]
|
[logging]
|
||||||
# The format the date displayed when logging things should take.
|
# The format the date displayed when logging things should take.
|
||||||
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))
|
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))
|
||||||
|
|
|
@ -31,15 +31,31 @@ class AITrainer {
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
async train() {
|
async train_all() {
|
||||||
for(let gateway of this.repo_gateway.iterate()) {
|
for(let gateway of this.repo_gateway.iterate()) {
|
||||||
let dataset = this.dataset_fetcher.fetch(gateway.id);
|
await this.train_dataset(gateway.id);
|
||||||
await this.train_dataset(dataset);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async train_dataset(dataset) {
|
async train_gateway(gateway_id) {
|
||||||
// TODO: Fill this in
|
let dataset_input = tf.data.generator(
|
||||||
|
this.dataset_fetcher.fetch_input.bind(null, gateway_id)
|
||||||
|
);
|
||||||
|
let dataset_output = tf.data.generator(
|
||||||
|
this.dataset_fetcher.fetch_output.bind(null, gateway_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
let dataset = tf.data.zip({
|
||||||
|
xs: dataset_input,
|
||||||
|
ys: dataset_output
|
||||||
|
}).shuffle(this.settings.ai.batch_size)
|
||||||
|
.batch(this.settings.ai.batch_size);
|
||||||
|
|
||||||
|
let result = await this.model.fitDataset(dataset, {
|
||||||
|
epochs: this.settings.ai.epochs,
|
||||||
|
batchSize: this.settings.ai.batch_size
|
||||||
|
});
|
||||||
|
console.log(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,11 +8,9 @@ class DatasetFetcher {
|
||||||
this.repo_rssi = RSSIRepo;
|
this.repo_rssi = RSSIRepo;
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch(gateway_id) {
|
*fetch_input(gateway_id) {
|
||||||
let result = [];
|
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
||||||
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id) {
|
yield [
|
||||||
result.push({
|
|
||||||
input: [
|
|
||||||
normalise(rssi.latitude,
|
normalise(rssi.latitude,
|
||||||
{ min: -90, max: +90 },
|
{ min: -90, max: +90 },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
|
@ -21,16 +19,19 @@ class DatasetFetcher {
|
||||||
{ min: -180, max: +180 },
|
{ min: -180, max: +180 },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
)
|
)
|
||||||
],
|
];
|
||||||
output: [
|
}
|
||||||
clamp(normalise(rssis.rssi,
|
}
|
||||||
|
|
||||||
|
*fetch_output(gateway_id) {
|
||||||
|
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
|
||||||
|
yield [
|
||||||
|
clamp(normalise(rssi.rssi,
|
||||||
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
|
||||||
{ min: 0, max: 1 }
|
{ min: 0, max: 1 }
|
||||||
), 0, 1)
|
), 0, 1)
|
||||||
]
|
];
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue