Fix bugs in AI trainer. It works, but precisely it's doing si up for debate....
This commit is contained in:
parent
192e31a925
commit
5c7cc1e906
4 changed files with 21 additions and 10 deletions
|
@ -29,7 +29,7 @@ class RSSIRepo {
|
||||||
readings.latitude,
|
readings.latitude,
|
||||||
readings.longitude
|
readings.longitude
|
||||||
FROM rssis
|
FROM rssis
|
||||||
JOIN readings ON rssis.gateway_id = readings.id
|
JOIN readings ON rssis.reading_id = readings.id
|
||||||
WHERE gateway_id = :gateway_id`).iterate({
|
WHERE gateway_id = :gateway_id`).iterate({
|
||||||
gateway_id
|
gateway_id
|
||||||
});
|
});
|
||||||
|
|
|
@ -77,6 +77,12 @@ export default async function(c) {
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "train-ai":
|
||||||
|
l.log(`${a.fgreen}${a.hicol}Training AIs${a.reset}`);
|
||||||
|
let ai_trainer = c.cradle.AITrainer;
|
||||||
|
await ai_trainer.train_all();
|
||||||
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
l.error(`Error: Subcommand '${extras[0]}' not recognised.`);
|
l.error(`Error: Subcommand '${extras[0]}' not recognised.`);
|
||||||
l.error(`Perhaps you mistyped it, or it hasn't been implemented yet?`);
|
l.error(`Perhaps you mistyped it, or it hasn't been implemented yet?`);
|
||||||
|
|
|
@ -38,7 +38,7 @@ export default async function setup() {
|
||||||
DataProcessor: a.asClass(DataProcessor),
|
DataProcessor: a.asClass(DataProcessor),
|
||||||
|
|
||||||
AITrainer: a.asClass(AITrainer),
|
AITrainer: a.asClass(AITrainer),
|
||||||
DatasetTrainer: a.asClass(DatasetTrainer)
|
DatasetFetcher: a.asClass(DatasetFetcher)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Enable / disable colourising the output
|
// Enable / disable colourising the output
|
||||||
|
|
|
@ -3,8 +3,9 @@
|
||||||
import tf from '@tensorflow/tfjs-node-gpu';
|
import tf from '@tensorflow/tfjs-node-gpu';
|
||||||
|
|
||||||
class AITrainer {
|
class AITrainer {
|
||||||
constructor({ settings, GatewayRepo, DatasetFetcher }) {
|
constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
|
this.l = log;
|
||||||
this.dataset_fetcher = DatasetFetcher;
|
this.dataset_fetcher = DatasetFetcher;
|
||||||
this.repo_gateway = GatewayRepo;
|
this.repo_gateway = GatewayRepo;
|
||||||
this.model = this.generate_model();
|
this.model = this.generate_model();
|
||||||
|
@ -15,7 +16,7 @@ class AITrainer {
|
||||||
model.add(tf.layers.dense({
|
model.add(tf.layers.dense({
|
||||||
units: 256, // 256 nodes
|
units: 256, // 256 nodes
|
||||||
activation: "sigmoid", // Sigmoid activation function
|
activation: "sigmoid", // Sigmoid activation function
|
||||||
inputShape: [3], // 2 inputs - lat and long
|
inputShape: [2], // 2 inputs - lat and long
|
||||||
}))
|
}))
|
||||||
model.add(tf.layers.dense({
|
model.add(tf.layers.dense({
|
||||||
units: 1, // 1 output value - RSSI
|
units: 1, // 1 output value - RSSI
|
||||||
|
@ -24,33 +25,37 @@ class AITrainer {
|
||||||
|
|
||||||
model.compile({
|
model.compile({
|
||||||
optimizer: tf.train.adam(),
|
optimizer: tf.train.adam(),
|
||||||
loss: "absoluteDifference",
|
loss: tf.losses.absoluteDifference,
|
||||||
metrics: [ "accuracy", "meanSquaredError" ]
|
metrics: [ tf.metrics.meanSquaredError ]
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this.l.log(`Model:`);
|
||||||
|
model.summary();
|
||||||
|
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
async train_all() {
|
async train_all() {
|
||||||
for(let gateway of this.repo_gateway.iterate()) {
|
for(let gateway of this.repo_gateway.iterate()) {
|
||||||
await this.train_dataset(gateway.id);
|
await this.train_gateway(gateway.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async train_gateway(gateway_id) {
|
async train_gateway(gateway_id) {
|
||||||
let dataset_input = tf.data.generator(
|
let dataset_input = tf.data.generator(
|
||||||
this.dataset_fetcher.fetch_input.bind(null, gateway_id)
|
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||||
);
|
);
|
||||||
let dataset_output = tf.data.generator(
|
let dataset_output = tf.data.generator(
|
||||||
this.dataset_fetcher.fetch_output.bind(null, gateway_id)
|
this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
let dataset = tf.data.zip({
|
let dataset = tf.data.zip({
|
||||||
xs: dataset_input,
|
xs: dataset_input,
|
||||||
ys: dataset_output
|
ys: dataset_output
|
||||||
}).shuffle(this.settings.ai.batch_size)
|
}).shuffle(this.settings.ai.batch_size * 4)
|
||||||
.batch(this.settings.ai.batch_size);
|
.batch(this.settings.ai.batch_size);
|
||||||
|
|
||||||
|
|
||||||
let result = await this.model.fitDataset(dataset, {
|
let result = await this.model.fitDataset(dataset, {
|
||||||
epochs: this.settings.ai.epochs,
|
epochs: this.settings.ai.epochs,
|
||||||
batchSize: this.settings.ai.batch_size
|
batchSize: this.settings.ai.batch_size
|
||||||
|
|
Loading…
Reference in a new issue