Add variable categories support
commit 00075b1823 (parent 0356c218f2)
10 changed files with 213 additions and 126 deletions
README.md
@@ -4,7 +4,7 @@
 The example code & slide deck for a talk I gave on getting started with AI. A link to the unlisted YouTube video is available upon request (because it contains my face, and this is a public repo) - see [my website](https://starbeamrainbowlabs.com/) for ways to get in touch.
 
-For advanced users, see also the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
+This is the `post-talk-improvements` branch, which contains a number of adjustments and changes to this codebase. These changes improve the model, but make it more complex to understand and hence are kept in a separate branch to allow regular users to look at the simpler / easier to understand codebase.
 
 
 ## Dataset
@@ -16,6 +16,7 @@ The format is as follows:
 
 ```
 dataset_dir/
+	categories.txt
 	train/
 		1234,Comedy,Fantasy.jpg
 		4321,Western,Short,Drama.jpg
@@ -28,6 +29,9 @@ The format is as follows:
 The filenames of the images take the following format: `ID,GENRE_1,GENRE_2,GENRE_N.jpg`.
 
+The `categories.txt` file should contain 1 category name per line (the order matters, as the line numbers - starting from 0 - are used as the numerical ids when training the model).
+
 
 ## System / User Requirements
 - [Node.js](https://nodejs.org/)
 - [NPM](https://www.npmjs.com/) (installed by default with Node.js)
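To make these conventions concrete, here is a minimal parsing sketch (illustrative only, not part of this commit; `parse_poster_filename` is a hypothetical helper):

```
"use strict";

import fs from 'fs';

// Hypothetical helper (not part of this commit): parse a poster filename
// into its ID and 0-based category ids using categories.txt.
function parse_poster_filename(filename, categories_file) {
	const categories = fs.readFileSync(categories_file, "utf-8")
		.replace(/\r\n?/g, "\n")
		.split("\n")
		.map((line) => line.trim())
		.filter((line) => line.length > 0);
	
	// "4321,Western,Short,Drama.jpg" → [ "4321", "Western", "Short", "Drama" ]
	const parts = filename.replace(/\.[a-zA-Z]+$/, "").split(",");
	return {
		id: parts[0],
		cat_ids: parts.slice(1).map((name) => categories.indexOf(name))
	};
}

console.log(parse_poster_filename("4321,Western,Short,Drama.jpg", "dataset_dir/categories.txt"));
```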
package-lock.json (generated, 90 lines changed)
@@ -5,16 +5,16 @@
   "requires": true,
   "dependencies": {
     "@tensorflow/tfjs": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.7.0.tgz",
-      "integrity": "sha512-LTYK6+emFweYa3zn/o511JUR6s14/yGZpoXvFSUtdwolYHI+J50r/CyYeFpvtoTD7uwcNFQhbBAtp4L4e3Hsaw==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-2.8.0.tgz",
+      "integrity": "sha512-jIG+WqaBq3NsTU5SnJuxqfgkKLpjVPlU2dbVepKgEqVY7XdffiJEoLBslOSgEJQc0j3GskOmZRw4F9u3NqgB1g==",
       "requires": {
-        "@tensorflow/tfjs-backend-cpu": "2.7.0",
-        "@tensorflow/tfjs-backend-webgl": "2.7.0",
-        "@tensorflow/tfjs-converter": "2.7.0",
-        "@tensorflow/tfjs-core": "2.7.0",
-        "@tensorflow/tfjs-data": "2.7.0",
-        "@tensorflow/tfjs-layers": "2.7.0",
+        "@tensorflow/tfjs-backend-cpu": "2.8.0",
+        "@tensorflow/tfjs-backend-webgl": "2.8.0",
+        "@tensorflow/tfjs-converter": "2.8.0",
+        "@tensorflow/tfjs-core": "2.8.0",
+        "@tensorflow/tfjs-data": "2.8.0",
+        "@tensorflow/tfjs-layers": "2.8.0",
         "argparse": "^1.0.10",
         "chalk": "^4.1.0",
         "core-js": "3",
@@ -23,9 +23,9 @@
       }
     },
     "@tensorflow/tfjs-backend-cpu": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.7.0.tgz",
-      "integrity": "sha512-R6ORcWq3ub81ABvBZEZ8Ok5OOT59B4AsRe66ds7B/NK0nN+k6y37bR3ZDVjgkEKNWNvzB7ydODikge3GNmgQIQ==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.8.0.tgz",
+      "integrity": "sha512-RmseKyCWJR0Jz7HyRXm6VVjpaM9Rbql9Vr7Jrx4SCpYX8BzrRVxDYq9aMZqYh93qBj4xX5upoQ5y5zoaFjpbTw==",
       "requires": {
         "@types/seedrandom": "2.4.27",
         "seedrandom": "2.4.3"
@@ -39,11 +39,11 @@
       }
     },
     "@tensorflow/tfjs-backend-webgl": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.7.0.tgz",
-      "integrity": "sha512-K7Rk5YTSWOZ969EZvh3w786daPn2ub4mA2JsX7mXKhBPUaOP9dKbBdLj9buCuMcu4zVq2pAp0QwpHSa4PHm3xg==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-2.8.0.tgz",
+      "integrity": "sha512-GZNpRyXQBHe3Tkt5m63HjgGyklfvU5kAoIuZe+jyfznzxkn3dhmiiLyMByQkxuc0LMZD1987riv0ALMSAXvVbQ==",
       "requires": {
-        "@tensorflow/tfjs-backend-cpu": "2.7.0",
+        "@tensorflow/tfjs-backend-cpu": "2.8.0",
         "@types/offscreencanvas": "~2019.3.0",
         "@types/seedrandom": "2.4.27",
         "@types/webgl-ext": "0.0.30",
@@ -59,14 +59,14 @@
       }
     },
     "@tensorflow/tfjs-converter": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.7.0.tgz",
-      "integrity": "sha512-SBpKYn/MkN8US7DeTcnvqHpvp/WKcwzpdgkQF+eHMHEbS1lXSlt4BHhOFgRdLPzy1gEC9+6P0VdTE8NQ737t/Q=="
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-converter/-/tfjs-converter-2.8.0.tgz",
+      "integrity": "sha512-GvVzPu0+gtgfq0+UAA2mHV7z5TT359TAUUJHOI7/dDSlfCMfLx5bSa0CamW7PIPkzSn9urljXTbocdeNAdpJKQ=="
     },
     "@tensorflow/tfjs-core": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.7.0.tgz",
-      "integrity": "sha512-4w5zjK6C5nkLatHpzARVQNd5QKtIocJRwjZIwWcScT9z2z1dX4rVmDoUpYg1cdD4H+yRRdI0awRaI3SL34yy8Q==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-2.8.0.tgz",
+      "integrity": "sha512-MMU5wG9bRPccFjxfc5vORw8YkjfFrPUtnoPQ/1WOqud5l3z3a318WAGqzSloPUMOMPD0c7x0vPQBsUgteliC2w==",
       "requires": {
         "@types/offscreencanvas": "~2019.3.0",
         "@types/seedrandom": "2.4.27",
@@ -83,26 +83,26 @@
       }
     },
     "@tensorflow/tfjs-data": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.7.0.tgz",
-      "integrity": "sha512-gsVklCwqqlxhykI7U2Uy5c2hjommQCAi+3y2/LER4TNtzQTzWaGKyIXvuLuL0tE896yuzXILIMZhkUjDmUiGxA==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-data/-/tfjs-data-2.8.0.tgz",
+      "integrity": "sha512-U0UbL6Hffcq6O6r6Isb6cPM9GxXUcU998gcPSIBeTMaGnBGbm/tl6Nf7f8SbGV4076ojeeFr4+0ly8heyBnfew==",
       "requires": {
         "@types/node-fetch": "^2.1.2",
         "node-fetch": "~2.6.1"
       }
     },
     "@tensorflow/tfjs-layers": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.7.0.tgz",
-      "integrity": "sha512-78zsD2LLrHQuDYv0EeV83LiF0M69lKsBfuTB3FIBgS85gapZPyHh4wooKda2Y4H9EtLogU+C6bArZuDo8PaX+g=="
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-layers/-/tfjs-layers-2.8.0.tgz",
+      "integrity": "sha512-g7ZEXyo46osVbJAR90KXbwWr82OlcqZOohSEIMXSZ/egnXLQpQqUBmGg8p8Nw2es6jmpSfOBDivv1z59HExVvg=="
     },
     "@tensorflow/tfjs-node": {
-      "version": "2.7.0",
-      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.7.0.tgz",
-      "integrity": "sha512-0cWplm7AE40gi2llqoAp+lD/0X3dVJ8kb7Arrqb5lMhShRWUFZpULH+F0fJI6Yax4LBTzBi2SZKGL/O8krZsxg==",
+      "version": "2.8.0",
+      "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node/-/tfjs-node-2.8.0.tgz",
+      "integrity": "sha512-u62AqvIS9jaMrw09Hl/Dy1YKiqEkltQ1c+RoJuW0XE685bjRue7FRmkJy69+EvsnXLY6gvPR1y+LguTPql8ZwA==",
       "requires": {
-        "@tensorflow/tfjs": "2.7.0",
-        "@tensorflow/tfjs-core": "2.7.0",
+        "@tensorflow/tfjs": "2.8.0",
+        "@tensorflow/tfjs-core": "2.8.0",
         "adm-zip": "^0.4.11",
         "google-protobuf": "^3.9.2",
         "https-proxy-agent": "^2.2.1",
@@ -113,9 +113,9 @@
       }
     },
     "@types/node": {
-      "version": "14.14.7",
-      "resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.7.tgz",
-      "integrity": "sha512-Zw1vhUSQZYw+7u5dAwNbIA9TuTotpzY/OF7sJM9FqPOF3SPjKnxrjoTktXDZgUjybf4cWVBP7O8wvKdSaGHweg=="
+      "version": "14.14.14",
+      "resolved": "https://registry.npmjs.org/@types/node/-/node-14.14.14.tgz",
+      "integrity": "sha512-UHnOPWVWV1z+VV8k6L1HhG7UbGBgIdghqF3l9Ny9ApPghbjICXkUJSd/b9gOgQfjM1r+37cipdw/HJ3F6ICEnQ=="
     },
     "@types/node-fetch": {
       "version": "2.5.7",
@@ -314,9 +314,9 @@
       "integrity": "sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4="
     },
     "core-js": {
-      "version": "3.7.0",
-      "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.7.0.tgz",
-      "integrity": "sha512-NwS7fI5M5B85EwpWuIwJN4i/fbisQUwLwiSNUWeXlkAZ0sbBjLEvLvFLf1uzAUV66PcEPt4xCGCmOZSxVf3xzA=="
+      "version": "3.8.1",
+      "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.8.1.tgz",
+      "integrity": "sha512-9Id2xHY1W7m8hCl8NkhQn5CufmF/WuR30BTRewvCXc1aZd3kMECwNZ69ndLbekKfakw9Rf2Xyc+QR6E7Gg+obg=="
     },
     "core-util-is": {
       "version": "1.0.2",
@@ -485,9 +485,9 @@
       "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
     },
     "ini": {
-      "version": "1.3.5",
-      "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.5.tgz",
-      "integrity": "sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw=="
+      "version": "1.3.8",
+      "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
+      "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
     },
     "is-fullwidth-code-point": {
       "version": "1.0.0",
@@ -885,9 +885,9 @@
       "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="
     },
     "yargs": {
-      "version": "16.1.1",
-      "resolved": "https://registry.npmjs.org/yargs/-/yargs-16.1.1.tgz",
-      "integrity": "sha512-hAD1RcFP/wfgfxgMVswPE+z3tlPFtxG8/yWUrG2i17sTWGCGqWnxKcLTF4cUKDUK8fzokwsmO9H0TDkRbMHy8w==",
+      "version": "16.2.0",
+      "resolved": "https://registry.npmjs.org/yargs/-/yargs-16.2.0.tgz",
+      "integrity": "sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw==",
       "requires": {
         "cliui": "^7.0.2",
         "escalade": "^3.1.1",
package.json
@@ -20,7 +20,7 @@
 	"author": "Starbeamrainbowlabs",
 	"license": "MPL-2.0",
 	"dependencies": {
-		"@tensorflow/tfjs-node": "^2.7.0",
+		"@tensorflow/tfjs-node": "^2.8.0",
 		"applause-cli": "^1.5.1"
 	}
 }
@@ -13,7 +13,9 @@ export default async function () {
 	let cli = new CliParser(path.resolve(__dirname, "../package.json"));
 	cli.subcommand("train", "Trains a new AI")
 		.argument("input", "The input directory containing the training data", null, "string")
-		.argument("output", "Path to the output directory to save the trained AI to");
+		.argument("output", "Path to the output directory to save the trained AI to")
+		.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
+		.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean");
 	cli.subcommand("predict", "Predicts the genres of the specified image")
 		.argument("input", "Path to the input image")
 		.argument("ai-model", "Path to the saved AI model to load");
src/lib/Categories.mjs (new file, 40 lines)
@@ -0,0 +1,40 @@
+"use strict";
+
+import fs from 'fs';
+
+/**
+ * Represents a list of categories.
+ */
+class Categories {
+	constructor(in_filename) {
+		this.filename = in_filename;
+		if(!fs.existsSync(this.filename))
+			throw new Error(`Error: No such file or directory '${this.filename}' (it should exist and have 1 category name per line).`); // Error, not Exception - JavaScript has no global Exception class
+		
+		/**
+		 * A list of category names.
+		 * @type {string[]}
+		 */
+		this.values = fs.readFileSync(this.filename, "utf-8")
+			.replace(/\r\n?/g, "\n") // g flag: normalise every line ending, not just the first
+			.split("\n")
+			.map((line) => line.trim())
+			.filter((line) => line.length > 0);
+		
+		/**
+		 * The total number of categories.
+		 * @type {Number}
+		 */
+		this.count = this.values.length;
+	}
+	
+	to_id(cat_name) {
+		return this.values.indexOf(cat_name);
+	}
+	
+	to_name(id) {
+		return this.values[id];
+	}
+}
+
+export default Categories;
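A quick usage sketch for the new class (the path and category names here are illustrative):

```
import Categories from './src/lib/Categories.mjs';

const cats = new Categories("dataset_dir/categories.txt");
console.log(cats.count);          // total number of non-empty lines
console.log(cats.to_id("Drama")); // 0-based line number, or -1 if not listed
console.log(cats.to_name(0));     // name on the first line
```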
src/lib/DataPreprocessor.mjs
@@ -5,7 +5,7 @@ import path from 'path';
 
 import tf from '@tensorflow/tfjs-node';
 
-import genres from './Genres.mjs';
+import Categories from './Categories.mjs';
 
 class DataPreprocessor {
 	constructor(dir_input) {
@@ -14,21 +14,14 @@ class DataPreprocessor {
 		this.dir_input_train = path.join(dir_input, "train");
 		this.dir_input_validate = path.join(dir_input, "validate");
+		this.file_input_cats = path.join(dir_input, "categories.txt");
 		
 		if(!fs.existsSync(this.dir_input_train))
 			throw new Error(`Error: Failed to locate the directory containing the training data.`);
 		if(!fs.existsSync(this.dir_input_validate))
 			throw new Error(`Error: Failed to locate the directory containing the validation data.`);
 		
-		this.genres_count = genres.length;
-	}
-	
-	genre2id(genre_name) {
-		return genres.indexOf(genre_name);
-	}
-	
-	id2genre(id) {
-		return genres[id];
+		this.cats = new Categories(this.file_input_cats);
 	}
 	
 	async *data_from_dir(dirpath) {
@@ -47,27 +40,40 @@ class DataPreprocessor {
 				throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
 			
 			// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
-			let next_genres = filename.replace(/\.[a-zA-Z]+$/, "")
+			let next_cats = filename.replace(/\.[a-zA-Z]+$/, "")
 				.split(",")
 				.slice(1)
-				.map(this.genre2id);
+				.map(this.cats.to_id.bind(this.cats));
 			
 			
-			let genres_pretensor = Array(this.genres_count).fill(0);
-			for(let genre_id of next_genres)
-				genres_pretensor[genre_id] = 1;
+			let pretensor = Array(this.cats.count).fill(0);
+			for(let cat_id of next_cats)
+				pretensor[cat_id] = 1;
 			
-			let genres_tensor = tf.tensor(genres_pretensor);
+			let cats_tensor = tf.tensor(pretensor);
 			
-			// console.log(`>>>>>>>>> output shapes: ${imagetensor.shape}, ${genres_tensor.shape}`);
+			// console.log(`>>>>>>>>> output shapes: [${imagetensor.shape}], [${cats_tensor.shape}]`);
 			
-			yield {
+			let result = {
 				xs: imagetensor,
-				ys: genres_tensor
+				ys: cats_tensor
 			};
+			// console.log(`[DEBUG] yielding xs`, result.xs.dataSync());
+			// console.log(`[DEBUG] ys`); result.ys.print();
+			// console.log(`--------------------`);
+			yield result;
 		}
 	}
 	
+	/**
+	 * Returns a Categories object instance that represents the list of
+	 * categories.
+	 * @return {Categories}
+	 */
+	categories() {
+		return this.cats;
+	}
+	
 	dataset_train() {
 		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
 	}
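The label-building logic above turns each filename into a multi-hot vector of length `cats.count`. A worked sketch of just that step (the 5-entry categories.txt here is hypothetical):

```
import tf from '@tensorflow/tfjs-node';

// Hypothetical categories.txt contents, one name per line (ids 0-4):
const values = ["Comedy", "Drama", "Fantasy", "Short", "Western"];

// "4321,Western,Short,Drama.jpg" → category ids [4, 3, 1]
const next_cats = "4321,Western,Short,Drama.jpg"
	.replace(/\.[a-zA-Z]+$/, "")
	.split(",")
	.slice(1)
	.map((name) => values.indexOf(name));

// Multi-hot encode: one slot per category, 1 where the image has that label
const pretensor = Array(values.length).fill(0);
for(const cat_id of next_cats)
	pretensor[cat_id] = 1;

console.log(pretensor);                  // → [ 0, 1, 0, 1, 1 ]
console.log(tf.tensor(pretensor).shape); // → [ 5 ] - the ys for this image
```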
src/lib/FilmPredictor.mjs
@@ -7,16 +7,19 @@ import tf from '@tensorflow/tfjs-node';
 // import tf from '@tensorflow/tfjs-node-gpu';
 // import tf from '@tensorflow/tfjs';
 
-import genres from './Genres.mjs';
+import make_model from './model.mjs';
 
 class FilmPredictor {
-	constructor(settings) {
+	constructor(settings, in_categories) {
 		this.settings = settings;
 		
-		this.genres = genres.length;
+		this.cats = in_categories;
+		
+		this.image_size = 256;
+		this.use_crossentropy = false;
 		
 		this.batch_size = 32;
-		this.prefetch = 4;
+		this.prefetch = 64;
 	}
 	
 	async init(dirpath = null) {
@@ -29,7 +32,16 @@ class FilmPredictor {
 		if(!fs.existsSync(this.dir_checkpoints))
 			await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
 		
-		this.make_model();
+		await fs.promises.copyFile(
+			this.cats.filename,
+			path.join(this.settings.output, "categories.txt")
+		);
+		
+		this.model = make_model(
+			this.cats.count,
+			this.image_size,
+			this.use_crossentropy
+		);
 		
 		this.model.summary();
 	}
@@ -43,57 +55,6 @@ class FilmPredictor {
 		console.error(`>>> Model loading complete`);
 	}
 	
-	make_model() {
-		console.error(`>>> Creating new model`);
-		this.model = tf.sequential();
-		
-		this.model.add(tf.layers.conv2d({
-			name: "conv2d_1",
-			dataFormat: "channelsLast",
-			inputShape: [ 256, 256, 3 ],
-			kernelSize: 5,
-			filters: 3,
-			strides: 2,
-			activation: "relu"
-		}));
-		this.model.add(tf.layers.conv2d({
-			name: "conv2d_2",
-			dataFormat: "channelsLast",
-			kernelSize: 5,
-			filters: 3,
-			strides: 2,
-			activation: "relu"
-		}));
-		this.model.add(tf.layers.conv2d({
-			name: "conv2d_3",
-			dataFormat: "channelsLast",
-			kernelSize: 5,
-			filters: 3,
-			strides: 2,
-			activation: "relu"
-		}));
-		
-		// Reshape and flatten
-		let cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
-		this.model.add(tf.layers.reshape({
-			name: "reshape",
-			targetShape: [
-				cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
-			]
-		}));
-		this.model.add(tf.layers.dense({
-			name: "dense",
-			units: this.genres,
-			activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
-		}));
-		
-		this.model.compile({
-			optimizer: tf.train.adam(),
-			loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
-			metrics: [ "mse" /* meanSquaredError */ ]
-		});
-	}
-	
 	
 	async train(dataset_train, dataset_validate) {
 		dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
@@ -142,7 +103,7 @@ class FilmPredictor {
 		for(let i = 0; i < arr.length; i++) {
 			if(arr[i] < 0.5)
 				continue;
-			result.push(genres[i]);
+			result.push(this.cats.values[i]);
 		}
 		return result;
 	}
src/lib/model.mjs (new file, 67 lines)
@@ -0,0 +1,67 @@
+"use strict";
+
+import tf from '@tensorflow/tfjs-node';
+
+export default function(cats_count, image_size, use_crossentropy = false) {
+	console.error(`>>> Creating new model`);
+	let model = tf.sequential();
+	
+	model.add(tf.layers.conv2d({
+		name: "conv2d_1",
+		dataFormat: "channelsLast",
+		inputShape: [ image_size, image_size, 3 ],
+		kernelSize: 5,
+		filters: 3,
+		strides: 1,
+		activation: "relu"
+	}));
+	if(image_size > 32) {
+		model.add(tf.layers.conv2d({
+			name: "conv2d_2",
+			dataFormat: "channelsLast",
+			kernelSize: 5,
+			filters: 3,
+			strides: 2,
+			activation: "relu"
+		}));
+		model.add(tf.layers.conv2d({
+			name: "conv2d_3",
+			dataFormat: "channelsLast",
+			kernelSize: 5,
+			filters: 3,
+			strides: 2,
+			activation: "relu"
+		}));
+	}
+	
+	// Reshape and flatten
+	let cnn_stack_output_shape = model.layers[model.layers.length - 1].outputShape;
+	model.add(tf.layers.reshape({
+		name: "reshape",
+		targetShape: [
+			cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
+		]
+	}));
+	model.add(tf.layers.dense({
+		name: "dense",
+		units: cats_count,
+		activation: use_crossentropy ? "softmax" : "sigmoid" // softmax for single-label (cross-entropy) mode, sigmoid for multi-label
+	}));
+	
+	let loss = "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
+		metrics = [ "mse" /* meanSquaredError */ ];
+	
+	if(use_crossentropy) {
+		console.log(`>>> Using categorical cross-entropy loss`);
+		loss = "categoricalCrossentropy";
+		metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ];
+	}
+	
+	model.compile({
+		optimizer: tf.train.adam(),
+		loss,
+		metrics
+	});
+	
+	return model;
+}
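Since `make_model` is now a standalone factory, it can be exercised in isolation. A minimal sketch (the argument values here are illustrative):

```
import make_model from './src/lib/model.mjs';

// 20 categories, 256×256 inputs, multi-label mode (sigmoid output + MSE loss)
const model = make_model(20, 256, false);
model.summary();

// Single-label mode instead: softmax output + categorical cross-entropy loss
const single_label_model = make_model(20, 256, true);
```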
@@ -1,8 +1,10 @@
 "use strict";
 
+import path from 'path';
 import fs from 'fs';
 
 import FilmPredictor from '../../lib/FilmPredictor.mjs';
+import Categories from '../../lib/Categories.mjs';
 
 export default async function(settings) {
 	if(!fs.existsSync(settings.input)) {
@@ -14,7 +16,10 @@ export default async function(settings) {
 		process.exit(1);
 	}
 	
-	let model = new FilmPredictor(settings);
+	let model = new FilmPredictor(
+		settings,
+		new Categories(path.join(settings.ai_model, "categories.txt"))
+	);
 	await model.init(settings.ai_model); // We're loading a saved model here
 	
 	let result = await model.predict(settings.input);
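Putting the pieces together, prediction against a saved model now looks roughly like this (the settings object shape and paths are assumptions based on the subcommand above):

```
import path from 'path';

import Categories from './src/lib/Categories.mjs';
import FilmPredictor from './src/lib/FilmPredictor.mjs';

// Hypothetical settings; the real values come from applause-cli
const settings = { input: "poster.jpg", ai_model: "path/to/saved-model" };

const model = new FilmPredictor(
	settings,
	new Categories(path.join(settings.ai_model, "categories.txt"))
);
await model.init(settings.ai_model);               // load the saved model
console.log(await model.predict(settings.input));  // → e.g. [ "Drama", "Short" ]
```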
@@ -24,7 +24,9 @@ export default async function(settings) {
 	await fs.promises.mkdir(settings.output, { recursive: true, mode: 0o755 });
 	
 	let preprocessor = new DataPreprocessor(settings.input);
-	let model = new FilmPredictor(settings);
+	let model = new FilmPredictor(settings, preprocessor.cats);
+	model.image_size = settings.image_size;
+	model.use_crossentropy = settings.cross_entropy;
 	await model.init(); // We're training a new model here
 	
 	await model.train(