Predicting film genres from their posters with TensorFlow.js
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

127 lines
3.3 KiB

"use strict";
import path from 'path';
import fs from 'fs';
import tf from '@tensorflow/tfjs-node';
import genres from './Genres.mjs';
/**
 * Multi-label film-genre classifier built on TensorFlow.js (tfjs-node).
 *
 * The network is a small stack of strided conv2d layers over 256x256x3
 * inputs, flattened into a dense layer with one sigmoid unit per genre
 * exported by ./Genres.mjs.
 */
class FilmPredictor {
	/**
	 * @param {Object} settings - Run settings.
	 * @param {string} settings.output - Directory that receives checkpoints
	 *     (under `checkpoints/`) and the `metrics.stream.json` log. It is
	 *     created by init() if missing.
	 */
	constructor(settings) {
		this.settings = settings;
		// Number of genres, i.e. the number of output units of the final layer.
		this.genres = genres.length;
		this.batch_size = 32;
		this.prefetch = 4;
	}

	/**
	 * Creates the output/checkpoint directories, then loads an existing model
	 * or builds a fresh one, and prints its summary.
	 *
	 * @param {string|null} [dirpath=null] - Checkpoint directory (or direct
	 *     path to a `model.json`) to load; `null` builds a new model.
	 * @returns {Promise<void>}
	 */
	async init(dirpath = null) {
		// `recursive: true` makes mkdir a no-op when the directory already exists,
		// so no separate existence check is needed.
		await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
		this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
		await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });

		if(dirpath !== null)
			await this.load_model(dirpath);
		else
			this.make_model();

		this.model.summary();
	}

	/**
	 * Loads a saved LayersModel into `this.model`.
	 *
	 * @param {string} dirpath - A directory produced by `model.save("file://...")`
	 *     (containing `model.json`), or a direct path to a `model.json` file.
	 * @returns {Promise<void>}
	 * @throws {Error} If `dirpath` does not exist.
	 */
	async load_model(dirpath) {
		if(!fs.existsSync(dirpath)) {
			throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
		}
		// The tfjs-node file:// handler loads from a model.json *file*, while
		// model.save("file://<dir>") (as used in train()) writes <dir>/model.json.
		// Resolve a directory to its model.json so checkpoints can be reloaded.
		const stats = await fs.promises.stat(dirpath);
		const model_json = stats.isDirectory() ? path.join(dirpath, "model.json") : dirpath;
		console.error(`>>> Loading model from '${dirpath}'`);
		this.model = await tf.loadLayersModel(`file://${model_json}`);
		console.error(`>>> Model loading complete`);
	}

	/**
	 * Builds and compiles a new, untrained model into `this.model`.
	 * @returns {void}
	 */
	make_model() {
		console.error(`>>> Creating new model`);
		this.model = tf.sequential();

		// Three identical strided conv layers; only the first declares the input shape.
		for(let i = 1; i <= 3; i++) {
			const conv_config = {
				name: `conv2d_${i}`,
				dataFormat: "channelsLast",
				kernelSize: 5,
				filters: 3,
				strides: 2,
				activation: "relu"
			};
			if(i === 1) conv_config.inputShape = [ 256, 256, 3 ];
			this.model.add(tf.layers.conv2d(conv_config));
		}

		// Reshape and flatten: collapse [height, width, channels] of the last
		// conv layer's output into a single feature axis (index 0 is the batch).
		const cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
		this.model.add(tf.layers.reshape({
			name: "reshape",
			targetShape: [
				cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
			]
		}));
		this.model.add(tf.layers.dense({
			name: "dense",
			units: this.genres,
			activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
		}));

		this.model.compile({
			optimizer: tf.train.adam(),
			loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
			metrics: [ "mse" /* meanSquaredError */ ]
		});
	}

	/**
	 * Trains the model. After every epoch, saves a checkpoint to
	 * `<output>/checkpoints/<epoch>` and appends that epoch's metrics as one
	 * JSON line to `<output>/metrics.stream.json`.
	 *
	 * Requires init() to have been called (it sets `this.dir_checkpoints`).
	 *
	 * @param {tf.data.Dataset} dataset_train - Unbatched training samples.
	 * @param {tf.data.Dataset} dataset_validate - Unbatched validation samples.
	 * @param {number} [epochs=50] - Number of epochs to train for.
	 * @returns {Promise<tf.History>} The training history from fitDataset().
	 */
	async train(dataset_train, dataset_validate, epochs = 50) {
		const batched_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
		const batched_validate = dataset_validate.batch(this.batch_size).prefetch(this.prefetch);
		const metrics_file = path.join(this.settings.output, `metrics.stream.json`);

		// Note: fitDataset() has no `shuffle` option (it was silently ignored);
		// shuffle the source dataset (Dataset.shuffle()) before calling train().
		return await this.model.fitDataset(batched_train, {
			epochs,
			verbose: 1,
			validationData: batched_validate,
			yieldEvery: "batch",
			callbacks: {
				onEpochEnd: async (epoch, metrics) => {
					console.error(`>>> Epoch ${epoch} complete, metrics:`, metrics);
					const dir_output_checkpoint = `file://${path.join(this.dir_checkpoints, `${epoch}`)}`;
					// Checkpoint save and metrics logging are independent; run them concurrently.
					await Promise.all([
						this.model.save(dir_output_checkpoint),
						fs.promises.appendFile(metrics_file, JSON.stringify(metrics) + "\n")
					]);
				}
			}
		});
	}

	/**
	 * Not implemented yet; currently resolves to `undefined`.
	 * TODO: run `this.model.predict()` on a preprocessed poster tensor.
	 * @param {*} image - Poster to classify (expected format not yet defined).
	 * @returns {Promise<undefined>}
	 */
	async predict(image) {
	}
}
export default FilmPredictor;