film-poster-genres/src/lib/FilmPredictor.mjs

113 lines
2.9 KiB
JavaScript

"use strict";
import path from 'path';
import fs from 'fs';
import tf from '@tensorflow/tfjs-node';
// import tf from '@tensorflow/tfjs-node-gpu';
// import tf from '@tensorflow/tfjs';
import make_model from './model.mjs';
/**
 * Predicts film genres from poster images with a TensorFlow.js layers model.
 *
 * The model is either built fresh (init() with no argument) or restored from a
 * directory containing a model.json (init(dirpath)).
 */
class FilmPredictor {
	/**
	 * @param {object} settings - Run settings. `settings.output` is the directory
	 *   that receives categories.txt, checkpoints/ and metrics.stream.json.
	 * @param {object} in_categories - Category list. This class reads its
	 *   `filename`, `count` and `values` properties.
	 */
	constructor(settings, in_categories) {
		this.settings = settings;
		this.cats = in_categories;
		this.image_size = 256;
		this.use_crossentropy = false;
		this.batch_size = 32;
		this.prefetch = 64;
		// Set by init() for freshly built models.
		// train() fills it in lazily for models restored through load_model().
		this.dir_checkpoints = null;
	}

	/**
	 * Prepares the model for use.
	 * @param {string|null} [dirpath=null] - Directory containing a saved
	 *   model.json to load. When null, a new model is built via make_model(),
	 *   the output directories are created and the categories file is copied
	 *   to <output>/categories.txt.
	 * @returns {Promise<void>}
	 */
	async init(dirpath = null) {
		if(dirpath !== null) {
			await this.load_model(dirpath);
			return;
		}
		this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
		// recursive: true makes mkdir a no-op for directories that already exist,
		// so no existsSync() pre-check is needed.
		await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
		await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
		await fs.promises.copyFile(
			this.cats.filename,
			path.join(this.settings.output, "categories.txt")
		);
		this.model = make_model(
			this.cats.count,
			this.image_size,
			this.use_crossentropy
		);
		this.model.summary();
	}

	/**
	 * Loads a layers model from `<dirpath>/model.json`.
	 * @param {string} dirpath - Directory holding the saved model.
	 * @returns {Promise<void>}
	 * @throws {Error} If `dirpath` does not exist.
	 */
	async load_model(dirpath) {
		if(!fs.existsSync(dirpath))
			throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
		console.error(`>>> Loading model from '${dirpath}'`);
		this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
		console.error(`>>> Model loading complete`);
	}

	/**
	 * Trains the model.
	 *
	 * After every epoch the model is saved to `<output>/checkpoints/<epoch>`, and
	 * the epoch's metrics are appended as one JSON line to
	 * `<output>/metrics.stream.json`.
	 * @param {object} dataset_train - tf.data dataset of training samples; it is
	 *   batched (this.batch_size) and prefetched here.
	 * @param {object} dataset_validate - tf.data dataset of validation samples;
	 *   it is batched and prefetched the same way.
	 * @param {number} [epochs=50] - Number of epochs to train for.
	 * @returns {Promise<void>}
	 */
	async train(dataset_train, dataset_validate, epochs = 50) {
		if(this.dir_checkpoints == null) {
			// A model restored via init(dirpath) never ran the output setup.
			// Without this, the first checkpoint save would join an undefined path.
			this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
			await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
		}
		const batches_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
		const batches_validate = dataset_validate.batch(this.batch_size).prefetch(this.prefetch);
		await this.model.fitDataset(batches_train, {
			epochs,
			verbose: 1,
			validationData: batches_validate,
			yieldEvery: "batch",
			shuffle: false,
			callbacks: {
				onEpochEnd: async (epoch, metrics) => {
					console.error(`>>> Epoch ${epoch} complete, metrics:`, metrics);
					const dir_output_checkpoint = `file://${path.join(this.dir_checkpoints, `${epoch}`)}`;
					// The checkpoint save and the metrics append are independent,
					// so they run concurrently.
					await Promise.all([
						this.model.save(dir_output_checkpoint),
						fs.promises.appendFile(
							path.join(this.settings.output, `metrics.stream.json`), JSON.stringify(metrics) + "\n"
						)
					]);
				}
			}
		});
	}

	/**
	 * Predicts the genres of a single poster image.
	 *
	 * NOTE(review): the decoded image is neither resized to this.image_size nor
	 * rescaled. This assumes posters already match what the model was trained
	 * on — TODO confirm.
	 * @param {string} imagefilepath - Path to an image file readable by
	 *   tf.node.decodeImage (decoded with 3 channels).
	 * @param {number} [threshold=0.5] - Minimum score for a genre to be reported.
	 * @returns {Promise<Array>} Entries of `this.cats.values` whose score passes
	 *   the threshold.
	 * @throws {Error} If no file exists at `imagefilepath`.
	 */
	async predict(imagefilepath, threshold = 0.5) {
		if(!fs.existsSync(imagefilepath))
			throw new Error(`Error: No file exists at '${imagefilepath}'`);
		const image_data = await fs.promises.readFile(imagefilepath);
		const imagetensor = tf.tidy(() => {
			const image = tf.node.decodeImage(
				image_data,
				3 // channels
			);
			// Prepend a batch dimension of 1.
			return image.reshape([1, ...image.shape]);
		});
		let prediction = null;
		try {
			prediction = this.model.predict(imagetensor);
			const result_array = (await prediction.array())[0];
			return this.array2genres(result_array, threshold);
		} finally {
			// Tensors are backed by native memory in tfjs-node and are not
			// garbage collected, so release both explicitly.
			imagetensor.dispose();
			if(prediction !== null)
				prediction.dispose();
		}
	}

	/**
	 * Maps per-category scores to category values.
	 * @param {ArrayLike<number>} arr - One score per category, in the same order
	 *   as `this.cats.values`.
	 * @param {number} [threshold=0.5] - Scores below this value are dropped.
	 * @returns {Array} The `this.cats.values` entries whose score is not below
	 *   the threshold.
	 */
	array2genres(arr, threshold = 0.5) {
		const result = [];
		for(let i = 0; i < arr.length; i++) {
			if(arr[i] < threshold)
				continue;
			result.push(this.cats.values[i]);
		}
		return result;
	}
}
export default FilmPredictor;