Add variable categories support

This commit is contained in:
Starbeamrainbowlabs 2020-12-15 18:22:50 +00:00
parent 0356c218f2
commit 00075b1823
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
10 changed files with 213 additions and 126 deletions

View file

@ -4,7 +4,7 @@
The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch.
For advanced users, see also the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
This is the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
## Dataset
@ -16,6 +16,7 @@ The format is as follows:
```
+ dataset_dir/
+ categories.txt
+ train/
+ 1234,Comedy,Fantasy.jpg
+ 4321,Western,Short,Drama.jpg
@ -28,6 +29,9 @@ The format is as follows:
The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`.
The `categories.txt` file should contain 1 category name per line. The order matters: each category's position in the file (counting from 0, and ignoring blank lines) is used as its numerical id when training the model.
## System / User Requirements
- [Node.js](https://nodejs.org/)
- [NPM](https://www.npmjs.com/) (installed by default with Node.js)

90
package-lock.json generated
View file

@ -5,16 +5,16 @@
"requires": true,
"dependencies": {
"@tensorflow/tfjs": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.7.0.tgz",
"integrity": "sha512-LTYK6+emFweYa3zn/o511JUR6s14/yGZpoXvFSUtdwolYHI+J50r/CyYeFpvtoTD7uwcNFQhbBAtp4L4e3Hsaw==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.8.0.tgz",
"integrity": "sha512-jIG+WqaBq3NsTU5SnJuxqfgkKLpjVPlU2dbVepKgEqVY7XdffiJEoLBslOSgEJQc0j3GskOmZRw4F9u3NqgB1g==",
"requires": {
"@tensorflow/tfjs-backend-cpu": "2.7.0",
"@tensorflow/tfjs-backend-webgl": "2.7.0",
"@tensorflow/tfjs-converter": "2.7.0",
"@tensorflow/tfjs-core": "2.7.0",
"@tensorflow/tfjs-data": "2.7.0",
"@tensorflow/tfjs-layers": "2.7.0",
"@tensorflow/tfjs-backend-cpu": "2.8.0",
"@tensorflow/tfjs-backend-webgl": "2.8.0",
"@tensorflow/tfjs-converter": "2.8.0",
"@tensorflow/tfjs-core": "2.8.0",
"@tensorflow/tfjs-data": "2.8.0",
"@tensorflow/tfjs-layers": "2.8.0",
"argparse": "^1.0.10",
"chalk": "^4.1.0",
"core-js": "3",
@ -23,9 +23,9 @@
}
},
"@tensorflow/tfjs-backend-cpu": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.7.0.tgz",
"integrity": "sha512-R6ORcWq3ub81ABvBZEZ8Ok5OOT59B4AsRe66ds7B/NK0nN+k6y37bR3ZDVjgkEKNWNvzB7ydODikge3GNmgQIQ==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.8.0.tgz",
"integrity": "sha512-RmseKyCWJR0Jz7HyRXm6VVjpaM9Rbql9Vr7Jrx4SCpYX8BzrRVxDYq9aMZqYh93qBj4xX5upoQ5y5zoaFjpbTw==",
"requires": {
"@types/seedrandom": "2.4.27",
"seedrandom": "2.4.3"
@ -39,11 +39,11 @@
}
},
"@tensorflow/tfjs-backend-webgl": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.7.0.tgz",
"integrity": "sha512-K7Rk5YTSWOZ969EZvh3w786daPn2ub4mA2JsX7mXKhBPUaOP9dKbBdLj9buCuMcu4zVq2pAp0QwpHSa4PHm3xg==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.8.0.tgz",
"integrity": "sha512-GZNpRyXQBHe3Tkt5m63HjgGyklfvU5kAoIuZe+jyfznzxkn3dhmiiLyMByQkxuc0LMZD1987riv0ALMSAXvVbQ==",
"requires": {
"@tensorflow/tfjs-backend-cpu": "2.7.0",
"@tensorflow/tfjs-backend-cpu": "2.8.0",
"@types/offscreencanvas": "~2019.3.0",
"@types/seedrandom": "2.4.27",
"@types/webgl-ext": "0.0.30",
@ -59,14 +59,14 @@
}
},
"@tensorflow/tfjs-converter": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.7.0.tgz",
"integrity": "sha512-SBpKYn/MkN8US7DeTcnvqHpvp/WKcwzpdgkQF+eHMHEbS1lXSlt4BHhOFgRdLPzy1gEC9+6P0VdTE8NQ737t/Q=="
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.8.0.tgz",
"integrity": "sha512-GvVzPu0+gtgfq0+UAA2mHV7z5TT359TAUUJHOI7/dDSlfCMfLx5bSa0CamW7PIPkzSn9urljXTbocdeNAdpJKQ=="
},
"@tensorflow/tfjs-core": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.7.0.tgz",
"integrity": "sha512-4w5zjK6C5nkLatHpzARVQNd5QKtIocJRwjZIwWcScT9z2z1dX4rVmDoUpYg1cdD4H+yRRdI0awRaI3SL34yy8Q==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.8.0.tgz",
"integrity": "sha512-MMU5wG9bRPccFjxfc5vORw8YkjfFrPUtnoPQ/1WOqud5l3z3a318WAGqzSloPUMOMPD0c7x0vPQBsUgteliC2w==",
"requires": {
"@types/offscreencanvas": "~2019.3.0",
"@types/seedrandom": "2.4.27",
@ -83,26 +83,26 @@
}
},
"@tensorflow/tfjs-data": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.7.0.tgz",
"integrity": "sha512-gsVklCwqqlxhykI7U2Uy5c2hjommQCAi+3y2/LER4TNtzQTzWaGKyIXvuLuL0tE896yuzXILIMZhkUjDmUiGxA==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.8.0.tgz",
"integrity": "sha512-U0UbL6Hffcq6O6r6Isb6cPM9GxXUcU998gcPSIBeTMaGnBGbm/tl6Nf7f8SbGV4076ojeeFr4+0ly8heyBnfew==",
"requires": {
"@types/node-fetch": "^2.1.2",
"node-fetch": "~2.6.1"
}
},
"@tensorflow/tfjs-layers": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.7.0.tgz",
"integrity": "sha512-78zsD2LLrHQuDYv0EeV83LiF0M69lKsBfuTB3FIBgS85gapZPyHh4wooKda2Y4H9EtLogU+C6bArZuDo8PaX+g=="
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.8.0.tgz",
"integrity": "sha512-g7ZEXyo46osVbJAR90KXbwWr82OlcqZOohSEIMXSZ/egnXLQpQqUBmGg8p8Nw2es6jmpSfOBDivv1z59HExVvg=="
},
"@tensorflow/tfjs-node": {
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.7.0.tgz",
"integrity": "sha512-0cWplm7AE40gi2llqoAp+lD/0X3dVJ8kb7Arrqb5lMhShRWUFZpULH+F0fJI6Yax4LBTzBi2SZKGL/O8krZsxg==",
"version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.8.0.tgz",
"integrity": "sha512-u62AqvIS9jaMrw09Hl/Dy1YKiqEkltQ1c+RoJuW0XE685bjRue7FRmkJy69+EvsnXLY6gvPR1y+LguTPql8ZwA==",
"requires": {
"@tensorflow/tfjs": "2.7.0",
"@tensorflow/tfjs-core": "2.7.0",
"@tensorflow/tfjs": "2.8.0",
"@tensorflow/tfjs-core": "2.8.0",
"adm-zip": "^0.4.11",
"google-protobuf": "^3.9.2",
"https-proxy-agent": "^2.2.1",
@ -113,9 +113,9 @@
}
},
"@types/node": {
"version": "14.14.7",
"resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.7.tgz",
"integrity": "sha512-Zw1vhUSQZYw+7u5dAwNbIA9TuTotpzY/OF7sJM9FqPOF3SPjKnxrjoTktXDZgUjybf4cWVBP7O8wvKdSaGHweg=="
"version": "14.14.14",
"resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.14.tgz",
"integrity": "sha512-UHnOPWVWV1z+VV8k6L1HhG7UbGBgIdghqF3l9Ny9ApPghbjICXkUJSd/b9gOgQfjM1r+37cipdw/HJ3F6ICEnQ=="
},
"@types/node-fetch": {
"version": "2.5.7",
@ -314,9 +314,9 @@
"integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4="
},
"core-js": {
"version": "3.7.0",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.7.0.tgz",
"integrity": "sha512-NwS7fI5M5B85EwpWuIwJN4i/fbisQUwLwiSNUWeXlkAZ0sbBjLEvLvFLf1uzAUV66PcEPt4xCGCmOZSxVf3xzA=="
"version": "3.8.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.8.1.tgz",
"integrity": "sha512-9Id2xHY1W7m8hCl8NkhQn5CufmF/WuR30BTRewvCXc1aZd3kMECwNZ69ndLbekKfakw9Rf2Xyc+QR6E7Gg+obg=="
},
"core-util-is": {
"version": "1.0.2",
@ -485,9 +485,9 @@
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
},
"ini": {
"version": "1.3.5",
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.5.tgz",
"integrity": "sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw=="
"version": "1.3.8",
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
"integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
},
"is-fullwidth-code-point": {
"version": "1.0.0",
@ -885,9 +885,9 @@
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="
},
"yargs": {
"version": "16.1.1",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-16.1.1.tgz",
"integrity": "sha512-hAD1RcFP/wfgfxgMVswPE+z3tlPFtxG8/yWUrG2i17sTWGCGqWnxKcLTF4cUKDUK8fzokwsmO9H0TDkRbMHy8w==",
"version": "16.2.0",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-16.2.0.tgz",
"integrity": "sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw==",
"requires": {
"cliui": "^7.0.2",
"escalade": "^3.1.1",

View file

@ -20,7 +20,7 @@
"author": "Starbeamrainbowlabs",
"license": "MPL-2.0",
"dependencies": {
"@tensorflow/tfjs-node": "^2.7.0",
"@tensorflow/tfjs-node": "^2.8.0",
"applause-cli": "^1.5.1"
}
}

View file

@ -13,7 +13,9 @@ export default async function () {
let cli = new CliParser(path.resolve(__dirname, "../package.json"));
cli.subcommand("train", "Trains a new AI")
.argument("input", "The input directory containing the training data", null, "string")
.argument("output", "Path to the output directory to save the trained AI to");
.argument("output", "Path to the output directory to save the trained AI to")
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
cli.subcommand("predict", "Predicts the genres of the specified image")
.argument("input", "Path to the input image")
.argument("ai-model", "Path to the saved AI model to load");

40
src/lib/Categories.mjs Normal file
View file

@ -0,0 +1,40 @@
"use strict";
import fs from 'fs';
/**
* Represents a list of categories.
*/
class Categories {
	/**
	 * Loads the list of categories from a text file containing 1 category
	 * name per line. Leading / trailing whitespace on each line is stripped,
	 * and blank lines are ignored. The position of a name in the resulting
	 * list (starting from 0) is its numerical id.
	 * @param	{string}	in_filename	The path to the categories file to load.
	 * @throws	{Error}		If the specified file does not exist.
	 */
	constructor(in_filename) {
		this.filename = in_filename;
		if(!fs.existsSync(this.filename))
			throw new Error(`Error: No such file or directory '${this.filename}' (it should exist and have 1 category name per line).`);
		
		/**
		 * A list of category names.
		 * @type {string[]}
		 */
		this.values = fs.readFileSync(this.filename, "utf-8")
			// Split on every kind of line ending (\r\n, \r, and \n), not just the first one encountered
			.split(/\r\n?|\n/)
			.map((line) => line.trim())
			.filter((line) => line.length > 0);
		
		/**
		 * The total number of categories.
		 * @type {Number}
		 */
		this.count = this.values.length;
	}
	
	/**
	 * Converts a category name to its numerical id.
	 * @param	{string}	cat_name	The category name to look up.
	 * @return	{Number}	The index of the name in the category list, or -1 if the name is not in the list.
	 */
	to_id(cat_name) {
		return this.values.indexOf(cat_name);
	}
	
	/**
	 * Converts a numerical id back to its category name.
	 * @param	{Number}	id	The numerical id to look up.
	 * @return	{string|undefined}	The category name at that index, or undefined if there is no such index.
	 */
	to_name(id) {
		return this.values[id];
	}
}
export default Categories;

View file

@ -5,7 +5,7 @@ import path from 'path';
import tf from '@tensorflow/tfjs-node';
import genres from './Genres.mjs';
import Categories from './Categories.mjs';
class DataPreprocessor {
constructor(dir_input) {
@ -14,21 +14,14 @@ class DataPreprocessor {
this.dir_input_train = path.join(dir_input, "train");
this.dir_input_validate = path.join(dir_input, "validate");
this.file_input_cats = path.join(dir_input, "categories.txt");
if(!fs.existsSync(this.dir_input_train))
throw new Error(`Error: Failed to locate the directory containing the training data.`);
if(!fs.existsSync(this.dir_input_validate))
throw new Error(`Error: Failed to locate the directory containing the validation data.`);
this.genres_count = genres.length;
}
genre2id(genre_name) {
return genres.indexOf(genre_name);
}
id2genre(id) {
return genres[id];
this.cats = new Categories(this.file_input_cats);
}
async *data_from_dir(dirpath) {
@ -47,27 +40,40 @@ class DataPreprocessor {
throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
let next_genres = filename.replace(/\.[a-zA-Z]+$/, "")
let next_cats = filename.replace(/\.[a-zA-Z]+$/, "")
.split(",")
.slice(1)
.map(this.genre2id);
.map(this.cats.to_id.bind(this.cats));
let genres_pretensor = Array(this.genres_count).fill(0);
for(let genre_id of next_genres)
genres_pretensor[genre_id] = 1;
let pretensor = Array(this.cats.count).fill(0);
for(let cat_id of next_cats)
pretensor[cat_id] = 1;
let genres_tensor = tf.tensor(genres_pretensor);
let cats_tensor = tf.tensor(pretensor);
// console.log(`>>>>>>>>> output shapes: ${imagetensor.shape}, ${genres_tensor.shape}`);
// console.log(`>>>>>>>>> output shapes: [${imagetensor.shape}], [${cats_tensor.shape}]`);
yield {
let result = {
xs: imagetensor,
ys: genres_tensor
ys: cats_tensor
};
// console.log(`[DEBUG] yielding xs`, result.xs.dataSync());
// console.log(`[DEBUG] ys`); result.ys.print();
// console.log(`--------------------`);
yield result;
}
}
/**
* Returns a Categories object instance that represents the list of
* categories.
* @return {Categories}
*/
categories() {
return this.cats;
}
dataset_train() {
return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
}

View file

@ -7,16 +7,19 @@ import tf from '@tensorflow/tfjs-node';
// import tf from '@tensorflow/tfjs-node-gpu';
// import tf from '@tensorflow/tfjs';
import genres from './Genres.mjs';
import make_model from './model.mjs';
class FilmPredictor {
constructor(settings) {
constructor(settings, in_categories) {
this.settings = settings;
this.genres = genres.length;
this.cats = in_categories;
this.image_size = 256;
this.use_crossentropy = false;
this.batch_size = 32;
this.prefetch = 4;
this.prefetch = 64;
}
async init(dirpath = null) {
@ -29,7 +32,16 @@ class FilmPredictor {
if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
this.make_model();
await fs.promises.copyFile(
this.cats.filename,
path.join(this.settings.output, "categories.txt")
);
this.model = make_model(
this.cats.count,
this.image_size,
this.use_crossentropy
);
this.model.summary();
}
@ -43,57 +55,6 @@ class FilmPredictor {
console.error(`>>> Model loading complete`);
}
make_model() {
console.error(`>>> Creating new model`);
this.model = tf.sequential();
this.model.add(tf.layers.conv2d({
name: "conv2d_1",
dataFormat: "channelsLast",
inputShape: [ 256, 256, 3 ],
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
this.model.add(tf.layers.conv2d({
name: "conv2d_2",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
this.model.add(tf.layers.conv2d({
name: "conv2d_3",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
// Reshape and flatten
let cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
this.model.add(tf.layers.reshape({
name: "reshape",
targetShape: [
cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
]
}));
this.model.add(tf.layers.dense({
name: "dense",
units: this.genres,
activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
}));
this.model.compile({
optimizer: tf.train.adam(),
loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
metrics: [ "mse" /* meanSquaredError */ ]
});
}
async train(dataset_train, dataset_validate) {
dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
@ -142,7 +103,7 @@ class FilmPredictor {
for(let i = 0; i < arr.length; i++) {
if(arr[i] < 0.5)
continue;
result.push(genres[i]);
result.push(this.cats.values[i]);
}
return result;
}

67
src/lib/model.mjs Normal file
View file

@ -0,0 +1,67 @@
"use strict";
import tf from '@tensorflow/tfjs-node';
/**
 * Creates and compiles a new sequential convolutional model.
 * The model always starts with a single stride-1 convolutional layer. When
 * image_size is greater than 32, 2 further stride-2 convolutional layers are
 * stacked on top. The output of the convolutional stack is then flattened and
 * fed into a single dense layer with 1 unit per category.
 * @param	{Number}	cats_count			The number of categories (the number of units in the final dense layer).
 * @param	{Number}	image_size			The width and height of the (square, 3-channel) input images.
 * @param	{Boolean}	use_crossentropy	If true, use a softmax activation with categorical cross-entropy loss instead of a sigmoid activation with mean squared error.
 * @return	{tf.Sequential}	The newly compiled model.
 */
export default function(cats_count, image_size, use_crossentropy = false) {
	console.error(`>>> Creating new model`);
	
	// Every convolutional layer shares the same kernel size, filter count,
	// and activation - only the name, stride, and (for the first layer) the
	// input shape differ.
	const make_conv = (name, strides, extra = {}) => tf.layers.conv2d({
		name,
		dataFormat: "channelsLast",
		...extra,
		kernelSize: 5,
		filters: 3,
		strides,
		activation: "relu"
	});
	
	const model = tf.sequential();
	model.add(make_conv("conv2d_1", 1, {
		inputShape: [ image_size, image_size, 3 ]
	}));
	
	// The extra downsampling layers are only added for larger images
	if(image_size > 32) {
		for(const layer_name of [ "conv2d_2", "conv2d_3" ])
			model.add(make_conv(layer_name, 2));
	}
	
	// Reshape and flatten: collapse everything except the leading batch
	// dimension of the last convolutional layer's output into 1 dimension
	const last_layer = model.layers[model.layers.length - 1];
	const [ , dim_a, dim_b, dim_c ] = last_layer.outputShape;
	model.add(tf.layers.reshape({
		name: "reshape",
		targetShape: [ dim_a * dim_b * dim_c ]
	}));
	
	// If you're only predicting a single label at a time, then "softmax" is
	// the activation to use - otherwise "sigmoid"
	const activation = use_crossentropy ? "softmax" : "sigmoid";
	model.add(tf.layers.dense({
		name: "dense",
		units: cats_count,
		activation
	}));
	
	let loss, metrics;
	if(use_crossentropy) {
		console.log(`>>> Using categorical cross-entropy loss`);
		loss = "categoricalCrossentropy";
		metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ];
	}
	else {
		// We want the root mean squared error, but Tensorflow.js doesn't have
		// an option for that so we'll do it in post when graphing with jq
		loss = "meanSquaredError";
		metrics = [ "mse" /* meanSquaredError */ ];
	}
	
	model.compile({
		optimizer: tf.train.adam(),
		loss,
		metrics
	});
	
	return model;
}

View file

@ -1,8 +1,10 @@
"use strict";
import path from 'path';
import fs from 'fs';
import FilmPredictor from '../../lib/FilmPredictor.mjs';
import Categories from '../../lib/Categories.mjs';
export default async function(settings) {
if(!fs.existsSync(settings.input)) {
@ -14,7 +16,10 @@ export default async function(settings) {
process.exit(1);
}
let model = new FilmPredictor(settings);
let model = new FilmPredictor(
settings,
new Categories(path.join(settings.ai_model, "categories.txt"))
);
await model.init(settings.ai_model); // We're training a new model here
let result = await model.predict(settings.input);

View file

@ -24,7 +24,9 @@ export default async function(settings) {
await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 });
let preprocessor = new DataPreprocessor(settings.input);
let model = new FilmPredictor(settings);
let model = new FilmPredictor(settings, preprocessor.cats);
model.image_size = settings.image_size;
model.use_crossentropy = settings.cross_entropy;
await model.init(); // We're training a new model here
await model.train(