Add variable categories support

This commit is contained in:
Starbeamrainbowlabs 2020-12-15 18:22:50 +00:00
parent 0356c218f2
commit 00075b1823
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
10 changed files with 213 additions and 126 deletions

View file

@ -4,7 +4,7 @@
The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch. The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch.
For advanced users, see also the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase. This is the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
## Dataset ## Dataset
@ -16,6 +16,7 @@ The format is as follows:
``` ```
+ dataset_dir/ + dataset_dir/
+ categories.txt
+ train/ + train/
+ 1234,Comedy,Fantasy.jpg + 1234,Comedy,Fantasy.jpg
+ 4321,Western,Short,Drama.jpg + 4321,Western,Short,Drama.jpg
@ -28,6 +29,9 @@ The format is as follows:
The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`. The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`.
The `categories.txt` file should contain 1 category name per line (the order matters, as the line numbers - starting from 0 - are used as the numerical ids when training the model).
## System / User Requirements ## System / User Requirements
- [Node.js](https://nodejs.org/) - [Node.js](https://nodejs.org/)
- [NPM](https://www.npmjs.com/) (installed by default with Node.js) - [NPM](https://www.npmjs.com/) (installed by default with Node.js)

90
package-lock.json generated
View file

@ -5,16 +5,16 @@
"requires": true, "requires": true,
"dependencies": { "dependencies": {
"@tensorflow/tfjs": { "@tensorflow/tfjs": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.8.0.tgz",
"integrity": "sha512-LTYK6+emFweYa3zn/o511JUR6s14/yGZpoXvFSUtdwolYHI+J50r/CyYeFpvtoTD7uwcNFQhbBAtp4L4e3Hsaw==", "integrity": "sha512-jIG+WqaBq3NsTU5SnJuxqfgkKLpjVPlU2dbVepKgEqVY7XdffiJEoLBslOSgEJQc0j3GskOmZRw4F9u3NqgB1g==",
"requires": { "requires": {
"@tensorflow/tfjs-backend-cpu": "2.7.0", "@tensorflow/tfjs-backend-cpu": "2.8.0",
"@tensorflow/tfjs-backend-webgl": "2.7.0", "@tensorflow/tfjs-backend-webgl": "2.8.0",
"@tensorflow/tfjs-converter": "2.7.0", "@tensorflow/tfjs-converter": "2.8.0",
"@tensorflow/tfjs-core": "2.7.0", "@tensorflow/tfjs-core": "2.8.0",
"@tensorflow/tfjs-data": "2.7.0", "@tensorflow/tfjs-data": "2.8.0",
"@tensorflow/tfjs-layers": "2.7.0", "@tensorflow/tfjs-layers": "2.8.0",
"argparse": "^1.0.10", "argparse": "^1.0.10",
"chalk": "^4.1.0", "chalk": "^4.1.0",
"core-js": "3", "core-js": "3",
@ -23,9 +23,9 @@
} }
}, },
"@tensorflow/tfjs-backend-cpu": { "@tensorflow/tfjs-backend-cpu": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.8.0.tgz",
"integrity": "sha512-R6ORcWq3ub81ABvBZEZ8Ok5OOT59B4AsRe66ds7B/NK0nN+k6y37bR3ZDVjgkEKNWNvzB7ydODikge3GNmgQIQ==", "integrity": "sha512-RmseKyCWJR0Jz7HyRXm6VVjpaM9Rbql9Vr7Jrx4SCpYX8BzrRVxDYq9aMZqYh93qBj4xX5upoQ5y5zoaFjpbTw==",
"requires": { "requires": {
"@types/seedrandom": "2.4.27", "@types/seedrandom": "2.4.27",
"seedrandom": "2.4.3" "seedrandom": "2.4.3"
@ -39,11 +39,11 @@
} }
}, },
"@tensorflow/tfjs-backend-webgl": { "@tensorflow/tfjs-backend-webgl": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.8.0.tgz",
"integrity": "sha512-K7Rk5YTSWOZ969EZvh3w786daPn2ub4mA2JsX7mXKhBPUaOP9dKbBdLj9buCuMcu4zVq2pAp0QwpHSa4PHm3xg==", "integrity": "sha512-GZNpRyXQBHe3Tkt5m63HjgGyklfvU5kAoIuZe+jyfznzxkn3dhmiiLyMByQkxuc0LMZD1987riv0ALMSAXvVbQ==",
"requires": { "requires": {
"@tensorflow/tfjs-backend-cpu": "2.7.0", "@tensorflow/tfjs-backend-cpu": "2.8.0",
"@types/offscreencanvas": "~2019.3.0", "@types/offscreencanvas": "~2019.3.0",
"@types/seedrandom": "2.4.27", "@types/seedrandom": "2.4.27",
"@types/webgl-ext": "0.0.30", "@types/webgl-ext": "0.0.30",
@ -59,14 +59,14 @@
} }
}, },
"@tensorflow/tfjs-converter": { "@tensorflow/tfjs-converter": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.8.0.tgz",
"integrity": "sha512-SBpKYn/MkN8US7DeTcnvqHpvp/WKcwzpdgkQF+eHMHEbS1lXSlt4BHhOFgRdLPzy1gEC9+6P0VdTE8NQ737t/Q==" "integrity": "sha512-GvVzPu0+gtgfq0+UAA2mHV7z5TT359TAUUJHOI7/dDSlfCMfLx5bSa0CamW7PIPkzSn9urljXTbocdeNAdpJKQ=="
}, },
"@tensorflow/tfjs-core": { "@tensorflow/tfjs-core": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.8.0.tgz",
"integrity": "sha512-4w5zjK6C5nkLatHpzARVQNd5QKtIocJRwjZIwWcScT9z2z1dX4rVmDoUpYg1cdD4H+yRRdI0awRaI3SL34yy8Q==", "integrity": "sha512-MMU5wG9bRPccFjxfc5vORw8YkjfFrPUtnoPQ/1WOqud5l3z3a318WAGqzSloPUMOMPD0c7x0vPQBsUgteliC2w==",
"requires": { "requires": {
"@types/offscreencanvas": "~2019.3.0", "@types/offscreencanvas": "~2019.3.0",
"@types/seedrandom": "2.4.27", "@types/seedrandom": "2.4.27",
@ -83,26 +83,26 @@
} }
}, },
"@tensorflow/tfjs-data": { "@tensorflow/tfjs-data": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.8.0.tgz",
"integrity": "sha512-gsVklCwqqlxhykI7U2Uy5c2hjommQCAi+3y2/LER4TNtzQTzWaGKyIXvuLuL0tE896yuzXILIMZhkUjDmUiGxA==", "integrity": "sha512-U0UbL6Hffcq6O6r6Isb6cPM9GxXUcU998gcPSIBeTMaGnBGbm/tl6Nf7f8SbGV4076ojeeFr4+0ly8heyBnfew==",
"requires": { "requires": {
"@types/node-fetch": "^2.1.2", "@types/node-fetch": "^2.1.2",
"node-fetch": "~2.6.1" "node-fetch": "~2.6.1"
} }
}, },
"@tensorflow/tfjs-layers": { "@tensorflow/tfjs-layers": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.8.0.tgz",
"integrity": "sha512-78zsD2LLrHQuDYv0EeV83LiF0M69lKsBfuTB3FIBgS85gapZPyHh4wooKda2Y4H9EtLogU+C6bArZuDo8PaX+g==" "integrity": "sha512-g7ZEXyo46osVbJAR90KXbwWr82OlcqZOohSEIMXSZ/egnXLQpQqUBmGg8p8Nw2es6jmpSfOBDivv1z59HExVvg=="
}, },
"@tensorflow/tfjs-node": { "@tensorflow/tfjs-node": {
"version": "2.7.0", "version": "2.8.0",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.7.0.tgz", "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.8.0.tgz",
"integrity": "sha512-0cWplm7AE40gi2llqoAp+lD/0X3dVJ8kb7Arrqb5lMhShRWUFZpULH+F0fJI6Yax4LBTzBi2SZKGL/O8krZsxg==", "integrity": "sha512-u62AqvIS9jaMrw09Hl/Dy1YKiqEkltQ1c+RoJuW0XE685bjRue7FRmkJy69+EvsnXLY6gvPR1y+LguTPql8ZwA==",
"requires": { "requires": {
"@tensorflow/tfjs": "2.7.0", "@tensorflow/tfjs": "2.8.0",
"@tensorflow/tfjs-core": "2.7.0", "@tensorflow/tfjs-core": "2.8.0",
"adm-zip": "^0.4.11", "adm-zip": "^0.4.11",
"google-protobuf": "^3.9.2", "google-protobuf": "^3.9.2",
"https-proxy-agent": "^2.2.1", "https-proxy-agent": "^2.2.1",
@ -113,9 +113,9 @@
} }
}, },
"@types/node": { "@types/node": {
"version": "14.14.7", "version": "14.14.14",
"resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.7.tgz", "resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.14.tgz",
"integrity": "sha512-Zw1vhUSQZYw+7u5dAwNbIA9TuTotpzY/OF7sJM9FqPOF3SPjKnxrjoTktXDZgUjybf4cWVBP7O8wvKdSaGHweg==" "integrity": "sha512-UHnOPWVWV1z+VV8k6L1HhG7UbGBgIdghqF3l9Ny9ApPghbjICXkUJSd/b9gOgQfjM1r+37cipdw/HJ3F6ICEnQ=="
}, },
"@types/node-fetch": { "@types/node-fetch": {
"version": "2.5.7", "version": "2.5.7",
@ -314,9 +314,9 @@
"integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4=" "integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4="
}, },
"core-js": { "core-js": {
"version": "3.7.0", "version": "3.8.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.7.0.tgz", "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.8.1.tgz",
"integrity": "sha512-NwS7fI5M5B85EwpWuIwJN4i/fbisQUwLwiSNUWeXlkAZ0sbBjLEvLvFLf1uzAUV66PcEPt4xCGCmOZSxVf3xzA==" "integrity": "sha512-9Id2xHY1W7m8hCl8NkhQn5CufmF/WuR30BTRewvCXc1aZd3kMECwNZ69ndLbekKfakw9Rf2Xyc+QR6E7Gg+obg=="
}, },
"core-util-is": { "core-util-is": {
"version": "1.0.2", "version": "1.0.2",
@ -485,9 +485,9 @@
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
}, },
"ini": { "ini": {
"version": "1.3.5", "version": "1.3.8",
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.5.tgz", "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
"integrity": "sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw==" "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
}, },
"is-fullwidth-code-point": { "is-fullwidth-code-point": {
"version": "1.0.0", "version": "1.0.0",
@ -885,9 +885,9 @@
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==" "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="
}, },
"yargs": { "yargs": {
"version": "16.1.1", "version": "16.2.0",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-16.1.1.tgz", "resolved": "https://registry.npmjs.org/yargs/-/yargs-16.2.0.tgz",
"integrity": "sha512-hAD1RcFP/wfgfxgMVswPE+z3tlPFtxG8/yWUrG2i17sTWGCGqWnxKcLTF4cUKDUK8fzokwsmO9H0TDkRbMHy8w==", "integrity": "sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw==",
"requires": { "requires": {
"cliui": "^7.0.2", "cliui": "^7.0.2",
"escalade": "^3.1.1", "escalade": "^3.1.1",

View file

@ -20,7 +20,7 @@
"author": "Starbeamrainbowlabs", "author": "Starbeamrainbowlabs",
"license": "MPL-2.0", "license": "MPL-2.0",
"dependencies": { "dependencies": {
"@tensorflow/tfjs-node": "^2.7.0", "@tensorflow/tfjs-node": "^2.8.0",
"applause-cli": "^1.5.1" "applause-cli": "^1.5.1"
} }
} }

View file

@ -13,7 +13,9 @@ export default async function () {
let cli = new CliParser(path.resolve(__dirname, "../package.json")); let cli = new CliParser(path.resolve(__dirname, "../package.json"));
cli.subcommand("train", "Trains a new AI") cli.subcommand("train", "Trains a new AI")
.argument("input", "The input directory containing the training data", null, "string") .argument("input", "The input directory containing the training data", null, "string")
.argument("output", "Path to the output directory to save the trained AI to"); .argument("output", "Path to the output directory to save the trained AI to")
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
cli.subcommand("predict", "Predicts the genres of the specified image") cli.subcommand("predict", "Predicts the genres of the specified image")
.argument("input", "Path to the input image") .argument("input", "Path to the input image")
.argument("ai-model", "Path to the saved AI model to load"); .argument("ai-model", "Path to the saved AI model to load");

40
src/lib/Categories.mjs Normal file
View file

@ -0,0 +1,40 @@
"use strict";
import fs from 'fs';
/**
* Represents a list of categories.
*/
class Categories {
constructor(in_filename) {
this.filename = in_filename;
if(!fs.existsSync(this.filename))
throw new Exception(`Error: No such file or directory '${this.filename}' (it should exist and have 1 category name per line).`);
/**
* A list of category names.
* @type {string[]}
*/
this.values = fs.readFileSync(this.filename, "utf-8")
.replace(/\r\n?/, "\n")
.split("\n")
.map((line) => line.trim())
.filter((line) => line.length > 0);
/**
* The total number of categories.
* @type {Number}
*/
this.count = this.values.length;
}
to_id(cat_name) {
return this.values.indexOf(cat_name);
}
to_name(id) {
return this.values[id];
}
}
export default Categories;

View file

@ -5,7 +5,7 @@ import path from 'path';
import tf from '@tensorflow/tfjs-node'; import tf from '@tensorflow/tfjs-node';
import genres from './Genres.mjs'; import Categories from './Categories.mjs';
class DataPreprocessor { class DataPreprocessor {
constructor(dir_input) { constructor(dir_input) {
@ -14,21 +14,14 @@ class DataPreprocessor {
this.dir_input_train = path.join(dir_input, "train"); this.dir_input_train = path.join(dir_input, "train");
this.dir_input_validate = path.join(dir_input, "validate"); this.dir_input_validate = path.join(dir_input, "validate");
this.file_input_cats = path.join(dir_input, "categories.txt");
if(!fs.existsSync(this.dir_input_train)) if(!fs.existsSync(this.dir_input_train))
throw new Error(`Error: Failed to locate the directory containing the training data.`); throw new Error(`Error: Failed to locate the directory containing the training data.`);
if(!fs.existsSync(this.dir_input_validate)) if(!fs.existsSync(this.dir_input_validate))
throw new Error(`Error: Failed to locate the directory containing the validation data.`); throw new Error(`Error: Failed to locate the directory containing the validation data.`);
this.genres_count = genres.length; this.cats = new Categories(this.file_input_cats);
}
genre2id(genre_name) {
return genres.indexOf(genre_name);
}
id2genre(id) {
return genres[id];
} }
async *data_from_dir(dirpath) { async *data_from_dir(dirpath) {
@ -47,27 +40,40 @@ class DataPreprocessor {
throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`); throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning // Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
let next_genres = filename.replace(/\.[a-zA-Z]+$/, "") let next_cats = filename.replace(/\.[a-zA-Z]+$/, "")
.split(",") .split(",")
.slice(1) .slice(1)
.map(this.genre2id); .map(this.cats.to_id.bind(this.cats));
let genres_pretensor = Array(this.genres_count).fill(0); let pretensor = Array(this.cats.count).fill(0);
for(let genre_id of next_genres) for(let cat_id of next_cats)
genres_pretensor[genre_id] = 1; pretensor[cat_id] = 1;
let genres_tensor = tf.tensor(genres_pretensor); let cats_tensor = tf.tensor(pretensor);
// console.log(`>>>>>>>>> output shapes: ${imagetensor.shape}, ${genres_tensor.shape}`); // console.log(`>>>>>>>>> output shapes: [${imagetensor.shape}], [${cats_tensor.shape}]`);
yield { let result = {
xs: imagetensor, xs: imagetensor,
ys: genres_tensor ys: cats_tensor
}; };
// console.log(`[DEBUG] yielding xs`, result.xs.dataSync());
// console.log(`[DEBUG] ys`); result.ys.print();
// console.log(`--------------------`);
yield result;
} }
} }
/**
* Returns a Categories object instance that represents the list of
* categories.
* @return {Categories}
*/
categories() {
return this.cats;
}
dataset_train() { dataset_train() {
return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train)); return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
} }

View file

@ -7,16 +7,19 @@ import tf from '@tensorflow/tfjs-node';
// import tf from '@tensorflow/tfjs-node-gpu'; // import tf from '@tensorflow/tfjs-node-gpu';
// import tf from '@tensorflow/tfjs'; // import tf from '@tensorflow/tfjs';
import genres from './Genres.mjs'; import make_model from './model.mjs';
class FilmPredictor { class FilmPredictor {
constructor(settings) { constructor(settings, in_categories) {
this.settings = settings; this.settings = settings;
this.genres = genres.length; this.cats = in_categories;
this.image_size = 256;
this.use_crossentropy = false;
this.batch_size = 32; this.batch_size = 32;
this.prefetch = 4; this.prefetch = 64;
} }
async init(dirpath = null) { async init(dirpath = null) {
@ -29,7 +32,16 @@ class FilmPredictor {
if(!fs.existsSync(this.dir_checkpoints)) if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 }); await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
this.make_model(); await fs.promises.copyFile(
this.cats.filename,
path.join(this.settings.output, "categories.txt")
);
this.model = make_model(
this.cats.count,
this.image_size,
this.use_crossentropy
);
this.model.summary(); this.model.summary();
} }
@ -43,57 +55,6 @@ class FilmPredictor {
console.error(`>>> Model loading complete`); console.error(`>>> Model loading complete`);
} }
make_model() {
console.error(`>>> Creating new model`);
this.model = tf.sequential();
this.model.add(tf.layers.conv2d({
name: "conv2d_1",
dataFormat: "channelsLast",
inputShape: [ 256, 256, 3 ],
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
this.model.add(tf.layers.conv2d({
name: "conv2d_2",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
this.model.add(tf.layers.conv2d({
name: "conv2d_3",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
// Reshape and flatten
let cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
this.model.add(tf.layers.reshape({
name: "reshape",
targetShape: [
cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
]
}));
this.model.add(tf.layers.dense({
name: "dense",
units: this.genres,
activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
}));
this.model.compile({
optimizer: tf.train.adam(),
loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
metrics: [ "mse" /* meanSquaredError */ ]
});
}
async train(dataset_train, dataset_validate) { async train(dataset_train, dataset_validate) {
dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch); dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
@ -142,7 +103,7 @@ class FilmPredictor {
for(let i = 0; i < arr.length; i++) { for(let i = 0; i < arr.length; i++) {
if(arr[i] < 0.5) if(arr[i] < 0.5)
continue; continue;
result.push(genres[i]); result.push(this.cats.values[i]);
} }
return result; return result;
} }

67
src/lib/model.mjs Normal file
View file

@ -0,0 +1,67 @@
"use strict";
import tf from '@tensorflow/tfjs-node';
export default function(cats_count, image_size, use_crossentropy = false) {
console.error(`>>> Creating new model`);
let model = tf.sequential();
model.add(tf.layers.conv2d({
name: "conv2d_1",
dataFormat: "channelsLast",
inputShape: [ image_size, image_size, 3 ],
kernelSize: 5,
filters: 3,
strides: 1,
activation: "relu"
}));
if(image_size > 32) {
model.add(tf.layers.conv2d({
name: "conv2d_2",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
model.add(tf.layers.conv2d({
name: "conv2d_3",
dataFormat: "channelsLast",
kernelSize: 5,
filters: 3,
strides: 2,
activation: "relu"
}));
}
// Reshape and flatten
let cnn_stack_output_shape = model.layers[model.layers.length - 1].outputShape;
model.add(tf.layers.reshape({
name: "reshape",
targetShape: [
cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
]
}));
model.add(tf.layers.dense({
name: "dense",
units: cats_count,
activation: use_crossentropy ? "softmax" : "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
}));
let loss = "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
metrics = [ "mse" /* meanSquaredError */ ];
if(use_crossentropy) {
console.log(`>>> Using categorical cross-entropy loss`);
loss = "categoricalCrossentropy";
metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ];
}
model.compile({
optimizer: tf.train.adam(),
loss,
metrics
});
return model;
}

View file

@ -1,8 +1,10 @@
"use strict"; "use strict";
import path from 'path';
import fs from 'fs'; import fs from 'fs';
import FilmPredictor from '../../lib/FilmPredictor.mjs'; import FilmPredictor from '../../lib/FilmPredictor.mjs';
import Categories from '../../lib/Categories.mjs';
export default async function(settings) { export default async function(settings) {
if(!fs.existsSync(settings.input)) { if(!fs.existsSync(settings.input)) {
@ -14,7 +16,10 @@ export default async function(settings) {
process.exit(1); process.exit(1);
} }
let model = new FilmPredictor(settings); let model = new FilmPredictor(
settings,
new Categories(path.join(settings.ai_model, "categories.txt"))
);
await model.init(settings.ai_model); // We're training a new model here await model.init(settings.ai_model); // We're training a new model here
let result = await model.predict(settings.input); let result = await model.predict(settings.input);

View file

@ -24,7 +24,9 @@ export default async function(settings) {
await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 }); await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 });
let preprocessor = new DataPreprocessor(settings.input); let preprocessor = new DataPreprocessor(settings.input);
let model = new FilmPredictor(settings); let model = new FilmPredictor(settings, preprocessor.cats);
model.image_size = settings.image_size;
model.use_crossentropy = settings.cross_entropy;
await model.init(); // We're training a new model here await model.init(); // We're training a new model here
await model.train( await model.train(