From 92574bc98c00846ea7fa27c782c8702668e0c49b Mon Sep 17 00:00:00 2001 From: Starbeamrainbowlabs Date: Mon, 22 Jul 2019 11:53:30 +0100 Subject: [PATCH] Serialise trained AIs and save them to disk --- server/settings.default.toml | 3 +++ server/train-ai/AITrainer.mjs | 30 +++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/server/settings.default.toml b/server/settings.default.toml index c828646..606413c 100644 --- a/server/settings.default.toml +++ b/server/settings.default.toml @@ -67,6 +67,9 @@ batch_size = 32 # The number of epochs to train for. epochs = 5 +# The directory to output trained AIs to, relative to the repository root. +output_directory = "app/ais/" + [logging] # The format the date displayed when logging things should take. # Allowed values: relative (e.g like when a Linux machine boots), absolute (e.g. like Nginx server logs), none (omits it entirely)) diff --git a/server/train-ai/AITrainer.mjs b/server/train-ai/AITrainer.mjs index ea5fcb5..39de0f6 100644 --- a/server/train-ai/AITrainer.mjs +++ b/server/train-ai/AITrainer.mjs @@ -1,14 +1,19 @@ "use strict"; +import path from 'path'; +import fs from 'fs'; + import tf from '@tensorflow/tfjs-node-gpu'; class AITrainer { - constructor({ settings, log, GatewayRepo, DatasetFetcher }) { + constructor({ settings, log, root_dir, GatewayRepo, DatasetFetcher }) { this.settings = settings; + this.root_dir = root_dir; this.l = log; this.dataset_fetcher = DatasetFetcher; this.repo_gateway = GatewayRepo; this.model = this.generate_model(); + } generate_model() { @@ -36,12 +41,29 @@ class AITrainer { } async train_all() { + for(let gateway of this.repo_gateway.iterate()) { - await this.train_gateway(gateway.id); + let filename = path.join(this.root_dir, "..", this.settings.ai.output_directory, `${gateway.id}`); + console.log(filename); + + if(!fs.existsSync(path.dirname(filename))) + await fs.promises.mkdir(path.dirname(filename), { recursive: true }); + + await this.train_gateway( + gateway.id, + filename + ); } } - async train_gateway(gateway_id) { + /** + * Trains an AI to predict the coverage of a specific gateway. + * @param {string} gateway_id The id of the gateway to train an AI for. + * @param {string} destination_filename The absolute path to the file to serialise the trained to. Required because we can't serialise and return a TensorFlow model, it has to be sent somewhere because the API is backwards and upside-down :-/ + * @return {Promise} A promise that resolves when training and serialisation is complete. + */ + async train_gateway(gateway_id, destination_filename) { + // TODO: Add samples here for locations that the gateway does NOT cover too let dataset_input = tf.data.generator( this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id) ); @@ -60,6 +82,8 @@ class AITrainer { epochs: this.settings.ai.epochs, batchSize: this.settings.ai.batch_size }); + + await this.model.save(`file://${destination_filename}`); console.log(result); } }