"use strict"; import path from 'path'; import fs from 'fs'; import tf from '@tensorflow/tfjs-node'; // import tf from '@tensorflow/tfjs-node-gpu'; // import tf from '@tensorflow/tfjs'; import genres from './Genres.mjs'; class FilmPredictor { constructor(settings) { this.settings = settings; this.genres = genres.length; this.batch_size = 32; this.prefetch = 4; } async init(dirpath = null) { if(dirpath !== null) await this.load_model(dirpath); else { if(!fs.existsSync(this.settings.output)) await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 }); this.dir_checkpoints = path.join(this.settings.output, "checkpoints"); if(!fs.existsSync(this.dir_checkpoints)) await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 }); this.make_model(); this.model.summary(); } } async load_model(dirpath) { if(!fs.existsSync(dirpath)) throw new Error(`Error: The directory ${dirpath} doesn't exist.`); console.error(`>>> Loading model from '${dirpath}'`); this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`); console.error(`>>> Model loading complete`); } make_model() { console.error(`>>> Creating new model`); this.model = tf.sequential(); this.model.add(tf.layers.conv2d({ name: "conv2d_1", dataFormat: "channelsLast", inputShape: [ 256, 256, 3 ], kernelSize: 5, filters: 3, strides: 2, activation: "relu" })); this.model.add(tf.layers.conv2d({ name: "conv2d_2", dataFormat: "channelsLast", kernelSize: 5, filters: 3, strides: 2, activation: "relu" })); this.model.add(tf.layers.conv2d({ name: "conv2d_3", dataFormat: "channelsLast", kernelSize: 5, filters: 3, strides: 2, activation: "relu" })); // Reshape and flatten let cnn_stack_output_shape = this.model.getLayer("conv2d_3").outputShape; this.model.add(tf.layers.reshape({ name: "reshape", targetShape: [ cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3] ] })); this.model.add(tf.layers.dense({ name: "dense", units: this.genres, activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead })); this.model.compile({ optimizer: tf.train.adam(), loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq metrics: [ "mse" /* meanSquaredError */ ] }); } async train(dataset_train, dataset_validate) { dataset_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch); dataset_validate = dataset_validate.batch(this.batch_size).prefetch(this.prefetch); await this.model.fitDataset(dataset_train, { epochs: 50, verbose: 1, validationData: dataset_validate, yieldEvery: "batch", shuffle: false, callbacks: { onEpochEnd: async (epoch, metrics) => { console.error(`>>> Epoch ${epoch} complete, metrics:`, metrics); let dir_output_checkpoint = `file://${path.join(this.dir_checkpoints, `${epoch.toString()}`)}`; await Promise.all([ this.model.save(dir_output_checkpoint), fs.promises.appendFile( path.join(this.settings.output, `metrics.stream.json`), JSON.stringify(metrics) + "\n" ) ]); } } }) } async predict(imagefilepath) { if(!fs.existsSync(imagefilepath)) throw new Error(`Error: No file exists at '${imagefilepath}'`); let image_data = await fs.promises.readFile(imagefilepath); let imagetensor = tf.tidy(() => { let image = tf.node.decodeImage( image_data, 3 // channels ); return image.reshape([1, ...image.shape ]); }) let result_array = (await this.model.predict(imagetensor).array())[0]; return this.array2genres(result_array); } array2genres(arr) { let result = []; for(let i = 0; i < arr.length; i++) { if(arr[i] < 0.5) continue; result.push(genres[i]); } return result; } } export default FilmPredictor;