Serialise trained AIs and save them to disk
This commit is contained in:
parent
dd9d39ba52
commit
92574bc98c
2 changed files with 30 additions and 3 deletions
|
@ -67,6 +67,9 @@ batch_size = 32
|
||||||
# The number of epochs to train for.
|
# The number of epochs to train for.
|
||||||
epochs = 5
|
epochs = 5
|
||||||
|
|
||||||
|
# The directory to output trained AIs to, relative to the repository root.
|
||||||
|
output_directory = "app/ais/"
|
||||||
|
|
||||||
[logging]
|
[logging]
|
||||||
# The format the date displayed when logging things should take.
|
# The format the date displayed when logging things should take.
|
||||||
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))
|
# Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely))
|
||||||
|
|
|
@ -1,14 +1,19 @@
|
||||||
"use strict";
|
"use strict";
|
||||||
|
|
||||||
|
import path from 'path';
|
||||||
|
import fs from 'fs';
|
||||||
|
|
||||||
import tf from '@tensorflow/tfjs-node-gpu';
|
import tf from '@tensorflow/tfjs-node-gpu';
|
||||||
|
|
||||||
class AITrainer {
|
class AITrainer {
|
||||||
constructor({ settings, log, GatewayRepo, DatasetFetcher }) {
|
constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) {
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
|
this.root_dir = root_dir;
|
||||||
this.l = log;
|
this.l = log;
|
||||||
this.dataset_fetcher = DatasetFetcher;
|
this.dataset_fetcher = DatasetFetcher;
|
||||||
this.repo_gateway = GatewayRepo;
|
this.repo_gateway = GatewayRepo;
|
||||||
this.model = this.generate_model();
|
this.model = this.generate_model();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generate_model() {
|
generate_model() {
|
||||||
|
@ -36,12 +41,29 @@ class AITrainer {
|
||||||
}
|
}
|
||||||
|
|
||||||
async train_all() {
|
async train_all() {
|
||||||
|
|
||||||
for(let gateway of this.repo_gateway.iterate()) {
|
for(let gateway of this.repo_gateway.iterate()) {
|
||||||
await this.train_gateway(gateway.id);
|
let filename = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}`);
|
||||||
|
console.log(filename);
|
||||||
|
|
||||||
|
if(!fs.existsSync(path.dirname(filename)))
|
||||||
|
await fs.promises.mkdir(path.dirname(filename), { recursive: true });
|
||||||
|
|
||||||
|
await this.train_gateway(
|
||||||
|
gateway.id,
|
||||||
|
filename
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async train_gateway(gateway_id) {
|
/**
|
||||||
|
* Trains an AI to predict the coverage of a specific gateway.
|
||||||
|
* @param {string} gateway_id The id of the gateway to train an AI for.
|
||||||
|
* @param {string} destination_filename The absolute path to the file to serialise the trained to. Required because we can't serialise and return a TensorFlow model, it has to be sent somewhere because the API is backwards and upside-down :-/
|
||||||
|
* @return {Promise} A promise that resolves when training and serialisation is complete.
|
||||||
|
*/
|
||||||
|
async train_gateway(gateway_id, destination_filename) {
|
||||||
|
// TODO: Add samples here for locations that the gateway does NOT cover too
|
||||||
let dataset_input = tf.data.generator(
|
let dataset_input = tf.data.generator(
|
||||||
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id)
|
||||||
);
|
);
|
||||||
|
@ -60,6 +82,8 @@ class AITrainer {
|
||||||
epochs: this.settings.ai.epochs,
|
epochs: this.settings.ai.epochs,
|
||||||
batchSize: this.settings.ai.batch_size
|
batchSize: this.settings.ai.batch_size
|
||||||
});
|
});
|
||||||
|
|
||||||
|
await this.model.save(`file://${destination_filename}`);
|
||||||
console.log(result);
|
console.log(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue