Fix bugs in AI trainer. It works, but precisely it's doing si up for debate....
This commit is contained in:
parent
192e31a925
commit
5c7cc1e906
4 changed files with 21 additions and 10 deletions
|
@ -29,7 +29,7 @@ class RSSIRepo {
|
|||
readings.latitude,
|
||||
readings.longitude
|
||||
FROM rssis
|
||||
JOIN readings ON rssis.gateway_id = readings.id
|
||||
JOIN readings ON rssis.reading_id = readings.id
|
||||
WHERE gateway_id = :gateway_id`).iterate({
|
||||
gateway_id
|
||||
});
|
||||
|
|
|
@ -77,6 +77,12 @@ export default async function(c) {
|
|||
|
||||
break;
|
||||
|
||||
case "train-ai":
|
||||
l.log(`${a.fgreen}${a.hicol}Training AIs${a.reset}`);
|
||||
let ai_trainer = c.cradle.AITrainer;
|
||||
await ai_trainer.train_all();
|
||||
break;
|
||||
|
||||
default:
|
||||
l.error(`Error: Subcommand '${extras[0]}' not recognised.`);
|
||||
l.error(`Perhaps you mistyped it, or it hasn't been implemented yet?`);
|
||||
|
|
|
@ -38,7 +38,7 @@ export default async function setup() {
|
|||
DataProcessor: a.asClass(DataProcessor),
|
||||
|
||||
AITrainer: a.asClass(AITrainer),
|
||||
DatasetTrainer: a.asClass(DatasetTrainer)
|
||||
DatasetFetcher: a.asClass(DatasetFetcher)
|
||||
});
|
||||
|
||||
// Enable / disable colourising the output
|
||||
|
|
|
@ -3,8 +3,9 @@
|
|||
import tf from '@tensorflow/tfjs-node-gpu';
|
||||
|
||||
class AITrainer {
|
||||
constructor({ settings, GatewayRepo, DatasetFetcher }) {
|
||||
constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
|
||||
this.settings = settings;
|
||||
this.l = log;
|
||||
this.dataset_fetcher = DatasetFetcher;
|
||||
this.repo_gateway = GatewayRepo;
|
||||
this.model = this.generate_model();
|
||||
|
@ -15,7 +16,7 @@ class AITrainer {
|
|||
model.add(tf.layers.dense({
|
||||
units: 256, // 256 nodes
|
||||
activation: "sigmoid", // Sigmoid activation function
|
||||
inputShape: [3], // 2 inputs - lat and long
|
||||
inputShape: [2], // 2 inputs - lat and long
|
||||
}))
|
||||
model.add(tf.layers.dense({
|
||||
units: 1, // 1 output value - RSSI
|
||||
|
@ -24,33 +25,37 @@ class AITrainer {
|
|||
|
||||
model.compile({
|
||||
optimizer: tf.train.adam(),
|
||||
loss: "absoluteDifference",
|
||||
metrics: [ "accuracy", "meanSquaredError" ]
|
||||
loss: tf.losses.absoluteDifference,
|
||||
metrics: [ tf.metrics.meanSquaredError ]
|
||||
});
|
||||
|
||||
this.l.log(`Model:`);
|
||||
model.summary();
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
async train_all() {
|
||||
for(let gateway of this.repo_gateway.iterate()) {
|
||||
await this.train_dataset(gateway.id);
|
||||
await this.train_gateway(gateway.id);
|
||||
}
|
||||
}
|
||||
|
||||
async train_gateway(gateway_id) {
|
||||
let dataset_input = tf.data.generator(
|
||||
this.dataset_fetcher.fetch_input.bind(null, gateway_id)
|
||||
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||
);
|
||||
let dataset_output = tf.data.generator(
|
||||
this.dataset_fetcher.fetch_output.bind(null, gateway_id)
|
||||
this.dataset_fetcher.fetch_output.bind(this.dataset_fetcher, gateway_id)
|
||||
);
|
||||
|
||||
let dataset = tf.data.zip({
|
||||
xs: dataset_input,
|
||||
ys: dataset_output
|
||||
}).shuffle(this.settings.ai.batch_size)
|
||||
}).shuffle(this.settings.ai.batch_size * 4)
|
||||
.batch(this.settings.ai.batch_size);
|
||||
|
||||
|
||||
let result = await this.model.fitDataset(dataset, {
|
||||
epochs: this.settings.ai.epochs,
|
||||
batchSize: this.settings.ai.batch_size
|
||||
|
|
Loading…
Reference in a new issue