diff --git a/server/settings.default.toml b/server/settings.default.toml index 2f30918..db68936 100644 --- a/server/settings.default.toml +++ b/server/settings.default.toml @@ -65,8 +65,14 @@ network_arch = [ 64 ] # The number of epochs to train for. epochs = 1000 -# The percentage, between 0 and 1, of data that should be set aside for validation. -validation_split = 0.1 +# Cut training short if the mean squared error drops below +# this value +error_threshold = 0.0001 + +# The learning rate that the neural networks should learn at +learning_rate = 0.1 +# The momentum term to use when learning +momentum = 0.1 # The directory to output trained AIs to, relative to the repository root. output_directory = "app/ais/" diff --git a/server/train-ai/AITrainer.mjs b/server/train-ai/AITrainer.mjs index fd28356..4e85217 100644 --- a/server/train-ai/AITrainer.mjs +++ b/server/train-ai/AITrainer.mjs @@ -70,7 +70,7 @@ class AITrainer { */ async train_gateway(gateway_id, destination_filename) { this.l.log(`${this.a.fgreen}${this.a.hicol}Training AI for gateway ${gateway_id}${this.a.reset}`); - let model = this.generate_neural_net(); + let net = this.generate_neural_net(); // let dataset_input = tf.data.generator( // this.dataset_fetcher.fetch_input.bind(this.dataset_fetcher, gateway_id) @@ -85,17 +85,17 @@ class AITrainer { // }).shuffle(this.settings.ai.batch_size * 4) // .batch(this.settings.ai.batch_size); // - let datasets = this.dataset_fetcher.fetch_all(gateway_id); + let dataset = this.dataset_fetcher.fetch_all(gateway_id); - let result = await model.fit( - tf.tensor(datasets.input), - tf.tensor(datasets.output), - { - epochs: this.settings.ai.epochs, - batchSize: this.settings.ai.batch_size, - validationSplit: this.settings.ai.validation_split - } - ); + let result = net.train(dataset, { + iterations: this.settings.ai.epochs, + errorThresh: this.settings.ai.error_threshold, + + learningRate: this.settings.ai.learning_rate, + momentum: this.settings.ai.momentum, + + timeout: Infinity + }); await model.save(`file://${destination_filename}`); // console.log(result); diff --git a/server/train-ai/DatasetFetcher.mjs b/server/train-ai/DatasetFetcher.mjs index 2217a46..7b82c44 100644 --- a/server/train-ai/DatasetFetcher.mjs +++ b/server/train-ai/DatasetFetcher.mjs @@ -28,10 +28,7 @@ class DatasetFetcher { fetch_all(gateway_id) { let gateway_location = this.repo_gateway.get_by_id(gateway_id); - let result = { - input: [], - output: [] - }; + let result = []; for(let rssi of this.repo_rssi.iterate_gateway(gateway_id)) { let next_input = this.normalise_latlng(rssi.latitude, rssi.longitude); let distance_from_gateway = haversine(gateway_location, rssi); @@ -47,12 +44,15 @@ class DatasetFetcher { console.log(`Distance from gateway: ${haversine(gateway_location, rssi)}m`); - result.output.push([ - clamp(normalise(rssi.rssi, - { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, - { min: 0, max: 1 } - ), 0, 1) - ]); + let next_output = clamp(normalise(rssi.rssi, + { min: this.settings.ai.rssi_min, max: this.settings.ai.rssi_max }, + { min: 0, max: 1 } + ), 0, 1); + + result.push({ + input: next_input, + output: next_output + }); } for(let reading of this.repo_reading.iterate_unreceived()) {