Compare commits
5 commits
master
...
post-talk-
Author | SHA1 | Date | |
---|---|---|---|
442c2080fb | |||
a5258f7460 | |||
10138cd310 | |||
926a82bebf | |||
00075b1823 |
10 changed files with 225 additions and 130 deletions
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch.
|
The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch.
|
||||||
|
|
||||||
For advanced users, see also the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
|
This is the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
|
||||||
|
|
||||||
|
|
||||||
## Dataset
|
## Dataset
|
||||||
|
@ -16,6 +16,7 @@ The format is as follows:
|
||||||
|
|
||||||
```
|
```
|
||||||
+ dataset_dir/
|
+ dataset_dir/
|
||||||
|
+ categories.txt
|
||||||
+ train/
|
+ train/
|
||||||
+ 1234,Comedy,Fantasy.jpg
|
+ 1234,Comedy,Fantasy.jpg
|
||||||
+ 4321,Western,Short,Drama.jpg
|
+ 4321,Western,Short,Drama.jpg
|
||||||
|
@ -28,6 +29,9 @@ The format is as follows:
|
||||||
|
|
||||||
The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`.
|
The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`.
|
||||||
|
|
||||||
|
The `categories.txt` file should contain 1 category name per line (the order matters, as the line numbers - starting from 0 - are used as the numerical ids when training the model).
|
||||||
|
|
||||||
|
|
||||||
## System / User Requirements
|
## System / User Requirements
|
||||||
- [Node.js](https://nodejs.org/)
|
- [Node.js](https://nodejs.org/)
|
||||||
- [NPM](https://www.npmjs.com/) (installed by default with Node.js)
|
- [NPM](https://www.npmjs.com/) (installed by default with Node.js)
|
||||||
|
|
90
package-lock.json
generated
90
package-lock.json
generated
|
@ -5,16 +5,16 @@
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@tensorflow/tfjs": {
|
"@tensorflow/tfjs": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.8.0.tgz",
|
||||||
"integrity": "sha512-LTYK6+emFweYa3zn/o511JUR6s14/yGZpoXvFSUtdwolYHI+J50r/CyYeFpvtoTD7uwcNFQhbBAtp4L4e3Hsaw==",
|
"integrity": "sha512-jIG+WqaBq3NsTU5SnJuxqfgkKLpjVPlU2dbVepKgEqVY7XdffiJEoLBslOSgEJQc0j3GskOmZRw4F9u3NqgB1g==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@tensorflow/tfjs-backend-cpu": "2.7.0",
|
"@tensorflow/tfjs-backend-cpu": "2.8.0",
|
||||||
"@tensorflow/tfjs-backend-webgl": "2.7.0",
|
"@tensorflow/tfjs-backend-webgl": "2.8.0",
|
||||||
"@tensorflow/tfjs-converter": "2.7.0",
|
"@tensorflow/tfjs-converter": "2.8.0",
|
||||||
"@tensorflow/tfjs-core": "2.7.0",
|
"@tensorflow/tfjs-core": "2.8.0",
|
||||||
"@tensorflow/tfjs-data": "2.7.0",
|
"@tensorflow/tfjs-data": "2.8.0",
|
||||||
"@tensorflow/tfjs-layers": "2.7.0",
|
"@tensorflow/tfjs-layers": "2.8.0",
|
||||||
"argparse": "^1.0.10",
|
"argparse": "^1.0.10",
|
||||||
"chalk": "^4.1.0",
|
"chalk": "^4.1.0",
|
||||||
"core-js": "3",
|
"core-js": "3",
|
||||||
|
@ -23,9 +23,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-backend-cpu": {
|
"@tensorflow/tfjs-backend-cpu": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.8.0.tgz",
|
||||||
"integrity": "sha512-R6ORcWq3ub81ABvBZEZ8Ok5OOT59B4AsRe66ds7B/NK0nN+k6y37bR3ZDVjgkEKNWNvzB7ydODikge3GNmgQIQ==",
|
"integrity": "sha512-RmseKyCWJR0Jz7HyRXm6VVjpaM9Rbql9Vr7Jrx4SCpYX8BzrRVxDYq9aMZqYh93qBj4xX5upoQ5y5zoaFjpbTw==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@types/seedrandom": "2.4.27",
|
"@types/seedrandom": "2.4.27",
|
||||||
"seedrandom": "2.4.3"
|
"seedrandom": "2.4.3"
|
||||||
|
@ -39,11 +39,11 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-backend-webgl": {
|
"@tensorflow/tfjs-backend-webgl": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.8.0.tgz",
|
||||||
"integrity": "sha512-K7Rk5YTSWOZ969EZvh3w786daPn2ub4mA2JsX7mXKhBPUaOP9dKbBdLj9buCuMcu4zVq2pAp0QwpHSa4PHm3xg==",
|
"integrity": "sha512-GZNpRyXQBHe3Tkt5m63HjgGyklfvU5kAoIuZe+jyfznzxkn3dhmiiLyMByQkxuc0LMZD1987riv0ALMSAXvVbQ==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@tensorflow/tfjs-backend-cpu": "2.7.0",
|
"@tensorflow/tfjs-backend-cpu": "2.8.0",
|
||||||
"@types/offscreencanvas": "~2019.3.0",
|
"@types/offscreencanvas": "~2019.3.0",
|
||||||
"@types/seedrandom": "2.4.27",
|
"@types/seedrandom": "2.4.27",
|
||||||
"@types/webgl-ext": "0.0.30",
|
"@types/webgl-ext": "0.0.30",
|
||||||
|
@ -59,14 +59,14 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-converter": {
|
"@tensorflow/tfjs-converter": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.8.0.tgz",
|
||||||
"integrity": "sha512-SBpKYn/MkN8US7DeTcnvqHpvp/WKcwzpdgkQF+eHMHEbS1lXSlt4BHhOFgRdLPzy1gEC9+6P0VdTE8NQ737t/Q=="
|
"integrity": "sha512-GvVzPu0+gtgfq0+UAA2mHV7z5TT359TAUUJHOI7/dDSlfCMfLx5bSa0CamW7PIPkzSn9urljXTbocdeNAdpJKQ=="
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-core": {
|
"@tensorflow/tfjs-core": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.8.0.tgz",
|
||||||
"integrity": "sha512-4w5zjK6C5nkLatHpzARVQNd5QKtIocJRwjZIwWcScT9z2z1dX4rVmDoUpYg1cdD4H+yRRdI0awRaI3SL34yy8Q==",
|
"integrity": "sha512-MMU5wG9bRPccFjxfc5vORw8YkjfFrPUtnoPQ/1WOqud5l3z3a318WAGqzSloPUMOMPD0c7x0vPQBsUgteliC2w==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@types/offscreencanvas": "~2019.3.0",
|
"@types/offscreencanvas": "~2019.3.0",
|
||||||
"@types/seedrandom": "2.4.27",
|
"@types/seedrandom": "2.4.27",
|
||||||
|
@ -83,26 +83,26 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-data": {
|
"@tensorflow/tfjs-data": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.8.0.tgz",
|
||||||
"integrity": "sha512-gsVklCwqqlxhykI7U2Uy5c2hjommQCAi+3y2/LER4TNtzQTzWaGKyIXvuLuL0tE896yuzXILIMZhkUjDmUiGxA==",
|
"integrity": "sha512-U0UbL6Hffcq6O6r6Isb6cPM9GxXUcU998gcPSIBeTMaGnBGbm/tl6Nf7f8SbGV4076ojeeFr4+0ly8heyBnfew==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@types/node-fetch": "^2.1.2",
|
"@types/node-fetch": "^2.1.2",
|
||||||
"node-fetch": "~2.6.1"
|
"node-fetch": "~2.6.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-layers": {
|
"@tensorflow/tfjs-layers": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.8.0.tgz",
|
||||||
"integrity": "sha512-78zsD2LLrHQuDYv0EeV83LiF0M69lKsBfuTB3FIBgS85gapZPyHh4wooKda2Y4H9EtLogU+C6bArZuDo8PaX+g=="
|
"integrity": "sha512-g7ZEXyo46osVbJAR90KXbwWr82OlcqZOohSEIMXSZ/egnXLQpQqUBmGg8p8Nw2es6jmpSfOBDivv1z59HExVvg=="
|
||||||
},
|
},
|
||||||
"@tensorflow/tfjs-node": {
|
"@tensorflow/tfjs-node": {
|
||||||
"version": "2.7.0",
|
"version": "2.8.0",
|
||||||
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.8.0.tgz",
|
||||||
"integrity": "sha512-0cWplm7AE40gi2llqoAp+lD/0X3dVJ8kb7Arrqb5lMhShRWUFZpULH+F0fJI6Yax4LBTzBi2SZKGL/O8krZsxg==",
|
"integrity": "sha512-u62AqvIS9jaMrw09Hl/Dy1YKiqEkltQ1c+RoJuW0XE685bjRue7FRmkJy69+EvsnXLY6gvPR1y+LguTPql8ZwA==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"@tensorflow/tfjs": "2.7.0",
|
"@tensorflow/tfjs": "2.8.0",
|
||||||
"@tensorflow/tfjs-core": "2.7.0",
|
"@tensorflow/tfjs-core": "2.8.0",
|
||||||
"adm-zip": "^0.4.11",
|
"adm-zip": "^0.4.11",
|
||||||
"google-protobuf": "^3.9.2",
|
"google-protobuf": "^3.9.2",
|
||||||
"https-proxy-agent": "^2.2.1",
|
"https-proxy-agent": "^2.2.1",
|
||||||
|
@ -113,9 +113,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@types/node": {
|
"@types/node": {
|
||||||
"version": "14.14.7",
|
"version": "14.14.14",
|
||||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.7.tgz",
|
"resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.14.tgz",
|
||||||
"integrity": "sha512-Zw1vhUSQZYw+7u5dAwNbIA9TuTotpzY/OF7sJM9FqPOF3SPjKnxrjoTktXDZgUjybf4cWVBP7O8wvKdSaGHweg=="
|
"integrity": "sha512-UHnOPWVWV1z+VV8k6L1HhG7UbGBgIdghqF3l9Ny9ApPghbjICXkUJSd/b9gOgQfjM1r+37cipdw/HJ3F6ICEnQ=="
|
||||||
},
|
},
|
||||||
"@types/node-fetch": {
|
"@types/node-fetch": {
|
||||||
"version": "2.5.7",
|
"version": "2.5.7",
|
||||||
|
@ -314,9 +314,9 @@
|
||||||
"integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4="
|
"integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4="
|
||||||
},
|
},
|
||||||
"core-js": {
|
"core-js": {
|
||||||
"version": "3.7.0",
|
"version": "3.8.1",
|
||||||
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.8.1.tgz",
|
||||||
"integrity": "sha512-NwS7fI5M5B85EwpWuIwJN4i/fbisQUwLwiSNUWeXlkAZ0sbBjLEvLvFLf1uzAUV66PcEPt4xCGCmOZSxVf3xzA=="
|
"integrity": "sha512-9Id2xHY1W7m8hCl8NkhQn5CufmF/WuR30BTRewvCXc1aZd3kMECwNZ69ndLbekKfakw9Rf2Xyc+QR6E7Gg+obg=="
|
||||||
},
|
},
|
||||||
"core-util-is": {
|
"core-util-is": {
|
||||||
"version": "1.0.2",
|
"version": "1.0.2",
|
||||||
|
@ -485,9 +485,9 @@
|
||||||
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
|
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
|
||||||
},
|
},
|
||||||
"ini": {
|
"ini": {
|
||||||
"version": "1.3.5",
|
"version": "1.3.8",
|
||||||
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.5.tgz",
|
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
|
||||||
"integrity": "sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw=="
|
"integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
|
||||||
},
|
},
|
||||||
"is-fullwidth-code-point": {
|
"is-fullwidth-code-point": {
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
|
@ -885,9 +885,9 @@
|
||||||
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="
|
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="
|
||||||
},
|
},
|
||||||
"yargs": {
|
"yargs": {
|
||||||
"version": "16.1.1",
|
"version": "16.2.0",
|
||||||
"resolved": "https://registry.npmjs.org/yargs/-/yargs-16.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/yargs/-/yargs-16.2.0.tgz",
|
||||||
"integrity": "sha512-hAD1RcFP/wfgfxgMVswPE+z3tlPFtxG8/yWUrG2i17sTWGCGqWnxKcLTF4cUKDUK8fzokwsmO9H0TDkRbMHy8w==",
|
"integrity": "sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw==",
|
||||||
"requires": {
|
"requires": {
|
||||||
"cliui": "^7.0.2",
|
"cliui": "^7.0.2",
|
||||||
"escalade": "^3.1.1",
|
"escalade": "^3.1.1",
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
"author": "Starbeamrainbowlabs",
|
"author": "Starbeamrainbowlabs",
|
||||||
"license": "MPL-2.0",
|
"license": "MPL-2.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@tensorflow/tfjs-node": "^2.7.0",
|
"@tensorflow/tfjs-node": "^2.8.0",
|
||||||
"applause-cli": "^1.5.1"
|
"applause-cli": "^1.5.1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,10 @@ export default async function () {
|
||||||
let cli = new CliParser(path.resolve(__dirname, "../package.json"));
|
let cli = new CliParser(path.resolve(__dirname, "../package.json"));
|
||||||
cli.subcommand("train", "Trains a new AI")
|
cli.subcommand("train", "Trains a new AI")
|
||||||
.argument("input", "The input directory containing the training data", null, "string")
|
.argument("input", "The input directory containing the training data", null, "string")
|
||||||
.argument("output", "Path to the output directory to save the trained AI to");
|
.argument("output", "Path to the output directory to save the trained AI to")
|
||||||
|
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
|
||||||
|
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
|
||||||
|
.argument("activation", "Set the activation function for the CNN layer(s) (default: relu)", "relu", "string");
|
||||||
cli.subcommand("predict", "Predicts the genres of the specified image")
|
cli.subcommand("predict", "Predicts the genres of the specified image")
|
||||||
.argument("input", "Path to the input image")
|
.argument("input", "Path to the input image")
|
||||||
.argument("ai-model", "Path to the saved AI model to load");
|
.argument("ai-model", "Path to the saved AI model to load");
|
||||||
|
|
40
src/lib/Categories.mjs
Normal file
40
src/lib/Categories.mjs
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
"use strict";
|
||||||
|
|
||||||
|
import fs from 'fs';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a list of categories.
|
||||||
|
*/
|
||||||
|
class Categories {
|
||||||
|
constructor(in_filename) {
|
||||||
|
this.filename = in_filename;
|
||||||
|
if(!fs.existsSync(this.filename))
|
||||||
|
throw new Error(`Error: No such file or directory '${this.filename}' (it should exist and have 1 category name per line).`);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of category names.
|
||||||
|
* @type {string[]}
|
||||||
|
*/
|
||||||
|
this.values = fs.readFileSync(this.filename, "utf-8")
|
||||||
|
.replace(/\r\n?/, "\n")
|
||||||
|
.split("\n")
|
||||||
|
.map((line) => line.trim())
|
||||||
|
.filter((line) => line.length > 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The total number of categories.
|
||||||
|
* @type {Number}
|
||||||
|
*/
|
||||||
|
this.count = this.values.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
to_id(cat_name) {
|
||||||
|
return this.values.indexOf(cat_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
to_name(id) {
|
||||||
|
return this.values[id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Categories;
|
|
@ -5,7 +5,7 @@ import path from 'path';
|
||||||
|
|
||||||
import tf from '@tensorflow/tfjs-node';
|
import tf from '@tensorflow/tfjs-node';
|
||||||
|
|
||||||
import genres from './Genres.mjs';
|
import Categories from './Categories.mjs';
|
||||||
|
|
||||||
class DataPreprocessor {
|
class DataPreprocessor {
|
||||||
constructor(dir_input) {
|
constructor(dir_input) {
|
||||||
|
@ -14,21 +14,14 @@ class DataPreprocessor {
|
||||||
|
|
||||||
this.dir_input_train = path.join(dir_input, "train");
|
this.dir_input_train = path.join(dir_input, "train");
|
||||||
this.dir_input_validate = path.join(dir_input, "validate");
|
this.dir_input_validate = path.join(dir_input, "validate");
|
||||||
|
this.file_input_cats = path.join(dir_input, "categories.txt");
|
||||||
|
|
||||||
if(!fs.existsSync(this.dir_input_train))
|
if(!fs.existsSync(this.dir_input_train))
|
||||||
throw new Error(`Error: Failed to locate the directory containing the training data.`);
|
throw new Error(`Error: Failed to locate the directory containing the training data.`);
|
||||||
if(!fs.existsSync(this.dir_input_validate))
|
if(!fs.existsSync(this.dir_input_validate))
|
||||||
throw new Error(`Error: Failed to locate the directory containing the validation data.`);
|
throw new Error(`Error: Failed to locate the directory containing the validation data.`);
|
||||||
|
|
||||||
this.genres_count = genres.length;
|
this.cats = new Categories(this.file_input_cats);
|
||||||
}
|
|
||||||
|
|
||||||
genre2id(genre_name) {
|
|
||||||
return genres.indexOf(genre_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
id2genre(id) {
|
|
||||||
return genres[id];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async *data_from_dir(dirpath) {
|
async *data_from_dir(dirpath) {
|
||||||
|
@ -37,37 +30,54 @@ class DataPreprocessor {
|
||||||
|
|
||||||
let filenames = await fs.promises.readdir(dirpath);
|
let filenames = await fs.promises.readdir(dirpath);
|
||||||
for(let filename of filenames) {
|
for(let filename of filenames) {
|
||||||
let imagetensor = tf.node.decodeImage(
|
let imagebin = await fs.promises.readFile(path.join(dirpath, filename));
|
||||||
await fs.promises.readFile(path.join(dirpath, filename)),
|
let imagetensor = tf.tidy(() => {
|
||||||
|
// Normalise the data to be 0 - 1
|
||||||
|
return tf.node.decodeImage(
|
||||||
|
imagebin,
|
||||||
3 // channels
|
3 // channels
|
||||||
);
|
).cast("float32").div(255);
|
||||||
|
});
|
||||||
if(imagetensor.shape[0] == 256
|
if(imagetensor.shape[0] == 256
|
||||||
&& imagetensor.shape[1] == 256
|
&& imagetensor.shape[1] == 256
|
||||||
&& imagetensor.shape[0] == 3)
|
&& imagetensor.shape[0] == 3)
|
||||||
throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
|
throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
|
||||||
|
|
||||||
// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
|
// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
|
||||||
let next_genres = filename.replace(/\.[a-zA-Z]+$/, "")
|
let next_cats = filename.replace(/\.[a-zA-Z]+$/, "")
|
||||||
.split(",")
|
.split(",")
|
||||||
.slice(1)
|
.slice(1)
|
||||||
.map(this.genre2id);
|
.map(this.cats.to_id.bind(this.cats));
|
||||||
|
|
||||||
|
|
||||||
let genres_pretensor = Array(this.genres_count).fill(0);
|
let pretensor = Array(this.cats.count).fill(0);
|
||||||
for(let genre_id of next_genres)
|
for(let cat_id of next_cats)
|
||||||
genres_pretensor[genre_id] = 1;
|
pretensor[cat_id] = 1;
|
||||||
|
|
||||||
let genres_tensor = tf.tensor(genres_pretensor);
|
let cats_tensor = tf.tensor(pretensor);
|
||||||
|
|
||||||
// console.log(`>>>>>>>>> output shapes: ${imagetensor.shape}, ${genres_tensor.shape}`);
|
// console.log(`>>>>>>>>> output shapes: [${imagetensor.shape}], [${cats_tensor.shape}]`);
|
||||||
|
|
||||||
yield {
|
let result = {
|
||||||
xs: imagetensor,
|
xs: imagetensor,
|
||||||
ys: genres_tensor
|
ys: cats_tensor
|
||||||
};
|
};
|
||||||
|
// console.log(`[DEBUG] yielding xs`, result.xs.dataSync());
|
||||||
|
// console.log(`[DEBUG] ys`); result.ys.print();
|
||||||
|
// console.log(`--------------------`);
|
||||||
|
yield result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a Categories object instance that represents the list of
|
||||||
|
* categories.
|
||||||
|
* @return {Categories}
|
||||||
|
*/
|
||||||
|
categories() {
|
||||||
|
return this.cats;
|
||||||
|
}
|
||||||
|
|
||||||
dataset_train() {
|
dataset_train() {
|
||||||
return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
|
return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,16 +7,20 @@ import tf from '@tensorflow/tfjs-node';
|
||||||
// import tf from '@tensorflow/tfjs-node-gpu';
|
// import tf from '@tensorflow/tfjs-node-gpu';
|
||||||
// import tf from '@tensorflow/tfjs';
|
// import tf from '@tensorflow/tfjs';
|
||||||
|
|
||||||
import genres from './Genres.mjs';
|
import make_model from './model.mjs';
|
||||||
|
|
||||||
class FilmPredictor {
|
class FilmPredictor {
|
||||||
constructor(settings) {
|
constructor(settings, in_categories) {
|
||||||
this.settings = settings;
|
this.settings = settings;
|
||||||
|
|
||||||
this.genres = genres.length;
|
this.cats = in_categories;
|
||||||
|
|
||||||
|
this.image_size = 256;
|
||||||
|
this.use_crossentropy = false;
|
||||||
|
this.activation = "relu";
|
||||||
|
|
||||||
this.batch_size = 32;
|
this.batch_size = 32;
|
||||||
this.prefetch = 4;
|
this.prefetch = 64;
|
||||||
}
|
}
|
||||||
|
|
||||||
async init(dirpath = null) {
|
async init(dirpath = null) {
|
||||||
|
@ -29,7 +33,17 @@ class FilmPredictor {
|
||||||
if(!fs.existsSync(this.dir_checkpoints))
|
if(!fs.existsSync(this.dir_checkpoints))
|
||||||
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
||||||
|
|
||||||
this.make_model();
|
await fs.promises.copyFile(
|
||||||
|
this.cats.filename,
|
||||||
|
path.join(this.settings.output, "categories.txt")
|
||||||
|
);
|
||||||
|
|
||||||
|
this.model = make_model(
|
||||||
|
this.cats.count,
|
||||||
|
this.image_size,
|
||||||
|
this.use_crossentropy,
|
||||||
|
this.activation
|
||||||
|
);
|
||||||
|
|
||||||
this.model.summary();
|
this.model.summary();
|
||||||
}
|
}
|
||||||
|
@ -43,57 +57,6 @@ class FilmPredictor {
|
||||||
console.error(`>>> Model loading complete`);
|
console.error(`>>> Model loading complete`);
|
||||||
}
|
}
|
||||||
|
|
||||||
make_model() {
|
|
||||||
console.error(`>>> Creating new model`);
|
|
||||||
this.model = tf.sequential();
|
|
||||||
|
|
||||||
this.model.add(tf.layers.conv2d({
|
|
||||||
name: "conv2d_1",
|
|
||||||
dataFormat: "channelsLast",
|
|
||||||
inputShape: [ 256, 256, 3 ],
|
|
||||||
kernelSize: 5,
|
|
||||||
filters: 3,
|
|
||||||
strides: 2,
|
|
||||||
activation: "relu"
|
|
||||||
}));
|
|
||||||
this.model.add(tf.layers.conv2d({
|
|
||||||
name: "conv2d_2",
|
|
||||||
dataFormat: "channelsLast",
|
|
||||||
kernelSize: 5,
|
|
||||||
filters: 3,
|
|
||||||
strides: 2,
|
|
||||||
activation: "relu"
|
|
||||||
}));
|
|
||||||
this.model.add(tf.layers.conv2d({
|
|
||||||
name: "conv2d_3",
|
|
||||||
dataFormat: "channelsLast",
|
|
||||||
kernelSize: 5,
|
|
||||||
filters: 3,
|
|
||||||
strides: 2,
|
|
||||||
activation: "relu"
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Reshape and flatten
|
|
||||||
let cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
|
|
||||||
this.model.add(tf.layers.reshape({
|
|
||||||
name: "reshape",
|
|
||||||
targetShape: [
|
|
||||||
cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
|
|
||||||
]
|
|
||||||
}));
|
|
||||||
this.model.add(tf.layers.dense({
|
|
||||||
name: "dense",
|
|
||||||
units: this.genres,
|
|
||||||
activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
|
|
||||||
}));
|
|
||||||
|
|
||||||
this.model.compile({
|
|
||||||
optimizer: tf.train.adam(),
|
|
||||||
loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
|
|
||||||
metrics: [ "mse" /* meanSquaredError */ ]
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async train(dataset_train, dataset_validate) {
|
async train(dataset_train, dataset_validate) {
|
||||||
dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
|
dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
|
||||||
|
@ -142,7 +105,7 @@ class FilmPredictor {
|
||||||
for(let i = 0; i < arr.length; i++) {
|
for(let i = 0; i < arr.length; i++) {
|
||||||
if(arr[i] < 0.5)
|
if(arr[i] < 0.5)
|
||||||
continue;
|
continue;
|
||||||
result.push(genres[i]);
|
result.push(this.cats.values[i]);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
67
src/lib/model.mjs
Normal file
67
src/lib/model.mjs
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
"use strict";
|
||||||
|
|
||||||
|
import tf from '@tensorflow/tfjs-node';
|
||||||
|
|
||||||
|
export default function(cats_count, image_size, use_crossentropy = false, activation = "relu") {
|
||||||
|
console.error(`>>> Creating new model (activation = ${activation})`);
|
||||||
|
let model = tf.sequential();
|
||||||
|
|
||||||
|
model.add(tf.layers.conv2d({
|
||||||
|
name: "conv2d_1",
|
||||||
|
dataFormat: "channelsLast",
|
||||||
|
inputShape: [ image_size, image_size, 3 ],
|
||||||
|
kernelSize: 5,
|
||||||
|
filters: 3,
|
||||||
|
strides: 1,
|
||||||
|
activation
|
||||||
|
}));
|
||||||
|
if(image_size > 32) {
|
||||||
|
model.add(tf.layers.conv2d({
|
||||||
|
name: "conv2d_2",
|
||||||
|
dataFormat: "channelsLast",
|
||||||
|
kernelSize: 5,
|
||||||
|
filters: 3,
|
||||||
|
strides: 2,
|
||||||
|
activation
|
||||||
|
}));
|
||||||
|
model.add(tf.layers.conv2d({
|
||||||
|
name: "conv2d_3",
|
||||||
|
dataFormat: "channelsLast",
|
||||||
|
kernelSize: 5,
|
||||||
|
filters: 3,
|
||||||
|
strides: 2,
|
||||||
|
activation
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape and flatten
|
||||||
|
let cnn_stack_output_shape = model.layers[model.layers.length - 1].outputShape;
|
||||||
|
model.add(tf.layers.reshape({
|
||||||
|
name: "reshape",
|
||||||
|
targetShape: [
|
||||||
|
cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
|
||||||
|
]
|
||||||
|
}));
|
||||||
|
model.add(tf.layers.dense({
|
||||||
|
name: "dense",
|
||||||
|
units: cats_count,
|
||||||
|
activation: use_crossentropy ? "softmax" : "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
|
||||||
|
}));
|
||||||
|
|
||||||
|
let loss = "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
|
||||||
|
metrics = [ "mse" /* meanSquaredError */ ];
|
||||||
|
|
||||||
|
if(use_crossentropy) {
|
||||||
|
console.log(`>>> Using categorical cross-entropy loss`);
|
||||||
|
loss = "categoricalCrossentropy";
|
||||||
|
metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ];
|
||||||
|
}
|
||||||
|
|
||||||
|
model.compile({
|
||||||
|
optimizer: tf.train.adam(),
|
||||||
|
loss,
|
||||||
|
metrics
|
||||||
|
});
|
||||||
|
|
||||||
|
return model;
|
||||||
|
}
|
|
@ -1,8 +1,10 @@
|
||||||
"use strict";
|
"use strict";
|
||||||
|
|
||||||
|
import path from 'path';
|
||||||
import fs from 'fs';
|
import fs from 'fs';
|
||||||
|
|
||||||
import FilmPredictor from '../../lib/FilmPredictor.mjs';
|
import FilmPredictor from '../../lib/FilmPredictor.mjs';
|
||||||
|
import Categories from '../../lib/Categories.mjs';
|
||||||
|
|
||||||
export default async function(settings) {
|
export default async function(settings) {
|
||||||
if(!fs.existsSync(settings.input)) {
|
if(!fs.existsSync(settings.input)) {
|
||||||
|
@ -14,7 +16,10 @@ export default async function(settings) {
|
||||||
process.exit(1);
|
process.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model = new FilmPredictor(settings);
|
let model = new FilmPredictor(
|
||||||
|
settings,
|
||||||
|
new Categories(path.join(settings.ai_model, "categories.txt"))
|
||||||
|
);
|
||||||
await model.init(settings.ai_model); // We're training a new model here
|
await model.init(settings.ai_model); // We're training a new model here
|
||||||
|
|
||||||
let result = await model.predict(settings.input);
|
let result = await model.predict(settings.input);
|
||||||
|
|
|
@ -24,7 +24,10 @@ export default async function(settings) {
|
||||||
await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 });
|
await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 });
|
||||||
|
|
||||||
let preprocessor = new DataPreprocessor(settings.input);
|
let preprocessor = new DataPreprocessor(settings.input);
|
||||||
let model = new FilmPredictor(settings);
|
let model = new FilmPredictor(settings, preprocessor.cats);
|
||||||
|
model.image_size = settings.image_size;
|
||||||
|
model.use_crossentropy = settings.cross_entropy;
|
||||||
|
model.activation = settings.activation;
|
||||||
await model.init(); // We're training a new model here
|
await model.init(); // We're training a new model here
|
||||||
|
|
||||||
await model.train(
|
await model.train(
|
||||||
|
|
Loading…
Reference in a new issue