152 lines
4.0 KiB
JavaScript
152 lines
4.0 KiB
JavaScript
"use strict";
|
|
|
|
import path from 'path';
|
|
import fs from 'fs';
|
|
|
|
import tf from '@tensorflow/tfjs-node';
|
|
// import tf from '@tensorflow/tfjs-node-gpu';
|
|
// import tf from '@tensorflow/tfjs';
|
|
|
|
import genres from './Genres.mjs';
|
|
|
|
/**
 * Multi-label film-genre classifier built on TensorFlow.js.
 *
 * Wraps a small convolutional network whose input layer is 256x256x3
 * (channelsLast) and whose output layer has one sigmoid unit per entry in the
 * `genres` list imported from ./Genres.mjs.
 */
class FilmPredictor {
	/**
	 * @param {Object} settings - Run configuration. `settings.output` is the
	 *     directory that checkpoints and the metrics stream are written to.
	 */
	constructor(settings) {
		this.settings = settings;

		// Number of output units: one per genre label.
		this.genres = genres.length;

		this.batch_size = 32;
		this.prefetch = 4;
	}

	/**
	 * Prepare the predictor for use.
	 *
	 * With a `dirpath`, loads a previously saved model from it. Without one,
	 * creates the output/checkpoint directories and builds a fresh compiled
	 * model, printing its summary.
	 * @param {string|null} [dirpath=null] - Directory containing `model.json`.
	 * @returns {Promise<void>}
	 */
	async init(dirpath = null) {
		if(dirpath !== null) {
			await this.load_model(dirpath);
			return;
		}

		await this.ensure_output_dirs();
		this.make_model();
		this.model.summary();
	}

	/**
	 * Create `settings.output` and its `checkpoints` subdirectory if they are
	 * missing, and record the checkpoint path in `this.dir_checkpoints`.
	 * Safe to call repeatedly.
	 * @returns {Promise<void>}
	 */
	async ensure_output_dirs() {
		this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
		// recursive: true also creates settings.output, and does not fail when
		// the directories already exist, so no existence pre-check is needed.
		await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
	}

	/**
	 * Load a saved Layers model from `<dirpath>/model.json` into `this.model`.
	 * @param {string} dirpath - Directory the model was saved to.
	 * @returns {Promise<void>}
	 * @throws {Error} If `dirpath` does not exist.
	 */
	async load_model(dirpath) {
		if(!fs.existsSync(dirpath))
			throw new Error(`Error: The directory ${dirpath} doesn't exist.`);

		console.error(`>>> Loading model from '${dirpath}'`);
		this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
		console.error(`>>> Model loading complete`);
	}

	/**
	 * Build and compile a new sequential model into `this.model`:
	 * three strided 5x5 conv layers, a flattening reshape, and a dense sigmoid
	 * output layer with one unit per genre.
	 */
	make_model() {
		console.error(`>>> Creating new model`);
		this.model = tf.sequential();

		// Three identically configured conv layers; only the first one declares
		// the model's input shape.
		for(let i = 1; i <= 3; i++) {
			const conv_config = {
				name: `conv2d_${i}`,
				dataFormat: "channelsLast",
				kernelSize: 5,
				filters: 3,
				strides: 2,
				activation: "relu"
			};
			if(i === 1)
				conv_config.inputShape = [ 256, 256, 3 ];
			this.model.add(tf.layers.conv2d(conv_config));
		}

		// Reshape and flatten: collapse the last conv layer's [height, width,
		// channels] output (index 0 is the batch dimension) into one vector.
		const cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape;
		this.model.add(tf.layers.reshape({
			name: "reshape",
			targetShape: [
				cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3]
			]
		}));

		this.model.add(tf.layers.dense({
			name: "dense",
			units: this.genres,
			activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
		}));

		this.model.compile({
			optimizer: tf.train.adam(),
			loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
			metrics: [ "mse" /* meanSquaredError */ ]
		});
	}

	/**
	 * Train `this.model` for 50 epochs.
	 *
	 * After every epoch the model is saved to `<output>/checkpoints/<epoch>`,
	 * and the epoch's metrics object is appended as one JSON line to
	 * `<output>/metrics.stream.json`.
	 * @param {tf.data.Dataset} dataset_train - Unbatched training dataset.
	 * @param {tf.data.Dataset} dataset_validate - Unbatched validation dataset.
	 * @returns {Promise<void>}
	 */
	async train(dataset_train, dataset_validate) {
		// The checkpoint directory is normally set up by init(), but not when
		// the model was loaded from disk; ensure it exists before training.
		await this.ensure_output_dirs();

		const batched_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
		const batched_validate = dataset_validate.batch(this.batch_size).prefetch(this.prefetch);
		const metrics_filepath = path.join(this.settings.output, "metrics.stream.json");

		await this.model.fitDataset(batched_train, {
			epochs: 50,
			verbose: 1,
			validationData: batched_validate,
			yieldEvery: "batch",
			shuffle: false,
			callbacks: {
				onEpochEnd: async (epoch, metrics) => {
					console.error(`>>> Epoch ${epoch} complete, metrics:`, metrics);

					const dir_output_checkpoint = `file://${path.join(this.dir_checkpoints, String(epoch))}`;
					await Promise.all([
						this.model.save(dir_output_checkpoint),
						fs.promises.appendFile(metrics_filepath, `${JSON.stringify(metrics)}\n`)
					]);
				}
			}
		});
	}

	/**
	 * Predict the genres for a single image file.
	 *
	 * NOTE(review): the decoded image is fed to the model without resizing or
	 * rescaling. The model's input layer is [256, 256, 3], so this presumably
	 * expects images that are already 256x256 and preprocessed the same way as
	 * the training data. Confirm against the data pipeline.
	 * @param {string} imagefilepath - Path to an image decodable by `tf.node.decodeImage`.
	 * @returns {Promise<Array>} Genres whose score is at least 0.5.
	 * @throws {Error} If no file exists at `imagefilepath`.
	 */
	async predict(imagefilepath) {
		if(!fs.existsSync(imagefilepath))
			throw new Error(`Error: No file exists at '${imagefilepath}'`);

		const image_data = await fs.promises.readFile(imagefilepath);

		// tidy() disposes the intermediate decoded image. The returned batched
		// tensor survives tidy() and is disposed in the finally block below.
		const imagetensor = tf.tidy(() => {
			const image = tf.node.decodeImage(image_data, 3 /* channels */);
			return image.reshape([ 1, ...image.shape ]);
		});

		let prediction = null;
		try {
			prediction = this.model.predict(imagetensor);
			const [ scores ] = await prediction.array();
			return this.array2genres(scores);
		}
		finally {
			// Release native tensor memory on both the success and error paths.
			imagetensor.dispose();
			if(prediction !== null)
				tf.dispose(prediction);
		}
	}

	/**
	 * Convert a per-genre score vector into the list of matching genre labels.
	 *
	 * Index `i` of `arr` corresponds to `genres[i]`. Scores beyond the end of
	 * the genre list are ignored, and NaN scores never count as a match.
	 * @param {number[]} arr - One score per genre.
	 * @param {number} [threshold=0.5] - Minimum score (inclusive) for a genre to be selected.
	 * @returns {Array} The selected genres, in genre-list order.
	 */
	array2genres(arr, threshold = 0.5) {
		const result = [];
		const limit = Math.min(arr.length, genres.length);
		for(let i = 0; i < limit; i++) {
			// `>=` rather than `!(x < threshold)` so NaN is excluded.
			if(arr[i] >= threshold)
				result.push(genres[i]);
		}
		return result;
	}
}
|
|
|
|
export default FilmPredictor;
|