Write the the data input, but it's untested.

There's also some confusion over the batch size. If it doesn't work, we 
probably need to  revise our understanding of the batch size in 
different contexts.
This commit is contained in:
Starbeamrainbowlabs 2019-07-18 16:34:25 +01:00
parent 6dbb6c3b87
commit 192e31a925
3 changed files with 52 additions and 27 deletions

View File

@ -59,6 +59,14 @@ devices = [
rssi_min = -150
rssi_max = 0
# Data is streamed from the SQLite databse. The batch size specifies how many
# rows to gather for training at once.
# Note that the entire available dataset will eventually end up being used -
# this setting just controls how much of ti is in memory at once.
batch_size = 32
# The number of epochs to train for.
epochs = 5
[logging]
# The format the date displayed when logging things should take.
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))

View File

@ -31,15 +31,31 @@ class AITrainer {
return model;
}
async train() {
async train_all() {
for(let gateway of this.repo_gateway.iterate()) {
let dataset = this.dataset_fetcher.fetch(gateway.id);
await this.train_dataset(dataset);
await this.train_dataset(gateway.id);
}
}
async train_dataset(dataset) {
// TODO: Fill this in
async train_gateway(gateway_id) {
let dataset_input = tf.data.generator(
this.dataset_fetcher.fetch_input.bind(null, gateway_id)
);
let dataset_output = tf.data.generator(
this.dataset_fetcher.fetch_output.bind(null, gateway_id)
);
let dataset = tf.data.zip({
xs: dataset_input,
ys: dataset_output
}).shuffle(this.settings.ai.batch_size)
.batch(this.settings.ai.batch_size);
let result = await this.model.fitDataset(dataset, {
epochs: this.settings.ai.epochs,
batchSize: this.settings.ai.batch_size
});
console.log(result);
}
}

View File

@ -8,29 +8,30 @@ class DatasetFetcher {
this.repo_rssi = RSSIRepo;
}
fetch(gateway_id) {
let result = [];
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id) {
result.push({
input: [
normalise(rssi.latitude,
{ min: -90, max: +90 },
{ min: 0, max: 1 }
),
normalise(rssi.longitude,
{ min: -180, max: +180 },
{ min: 0, max: 1 }
)
],
output: [
clamp(normalise(rssis.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 }
), 0, 1)
]
});
*fetch_input(gateway_id) {
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
yield [
normalise(rssi.latitude,
{ min: -90, max: +90 },
{ min: 0, max: 1 }
),
normalise(rssi.longitude,
{ min: -180, max: +180 },
{ min: 0, max: 1 }
)
];
}
}
*fetch_output(gateway_id) {
for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) {
yield [
clamp(normalise(rssi.rssi,
{ min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max },
{ min: 0, max: 1 }
), 0, 1)
];
}
return result;
}
}