film-poster-genres/src/lib/DataPreprocessor.mjs

87 lines
2.6 KiB
JavaScript

"use strict";
import fs from 'fs';
import path from 'path';
import tf from '@tensorflow/tfjs-node';
import Categories from './Categories.mjs';
/**
 * Loads labelled images from a dataset directory and exposes them as
 * TensorFlow.js datasets.
 *
 * The input directory must contain a `train` subdirectory, a `validate`
 * subdirectory, and a `categories.txt` file (the latter is handed to
 * Categories unchanged).
 */
class DataPreprocessor {
	/**
	 * @param {string} dir_input Path to the root directory of the dataset.
	 * @throws {Error} If the root directory, or its `train` / `validate`
	 *                 subdirectory, does not exist.
	 */
	constructor(dir_input) {
		if(!fs.existsSync(dir_input))
			throw new Error(`Error: The input directory '${dir_input}' does not exist.`);
		
		this.dir_input_train = path.join(dir_input, "train");
		this.dir_input_validate = path.join(dir_input, "validate");
		this.file_input_cats = path.join(dir_input, "categories.txt");
		
		if(!fs.existsSync(this.dir_input_train))
			throw new Error(`Error: Failed to locate the directory containing the training data.`);
		if(!fs.existsSync(this.dir_input_validate))
			throw new Error(`Error: Failed to locate the directory containing the validation data.`);
		
		this.cats = new Categories(this.file_input_cats);
	}
	
	/**
	 * Asynchronously yields one { xs, ys } pair per file in a directory.
	 *
	 * `xs` is the image decoded with 3 channels; `ys` is a 1D tensor of
	 * length `this.cats.count` with a 1 at the id of every category named
	 * in the filename and 0 elsewhere.
	 *
	 * Filenames are expected to be a comma-separated list whose first item
	 * is an id and whose remaining items are category names, followed by a
	 * file extension (e.g. `1234,comedy,drama.jpg`).
	 *
	 * @param {string} dirpath The directory to read images from.
	 * @throws {Error} If the directory does not exist, or a decoded image
	 *                 does not have the shape [ 256, 256, 3 ].
	 */
	async *data_from_dir(dirpath) {
		if(!fs.existsSync(dirpath))
			throw new Error(`[DataPreprocessor/dataset_dir] Error: The specified directory does not exist.`);
		
		// Height, width, channels
		const expected_shape = [ 256, 256, 3 ];
		
		const filenames = await fs.promises.readdir(dirpath);
		for(const filename of filenames) {
			const imagetensor = tf.node.decodeImage(
				await fs.promises.readFile(path.join(dirpath, filename)),
				3 // channels
			);
			
			// Reject anything that isn't exactly the expected rank and size
			const shape_ok = imagetensor.shape.length === expected_shape.length
				&& expected_shape.every((dim, i) => imagetensor.shape[i] === dim);
			if(!shape_ok) {
				const shape_actual = imagetensor.shape.join(", ");
				// Tensors are not garbage collected, so release it before bailing out
				imagetensor.dispose();
				throw new Error(`Error: Loaded image has a shape of [${shape_actual}], but a shape of [ 256, 256, 3 ] was expected.`);
			}
			
			// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
			const next_cats = filename.replace(/\.[a-zA-Z]+$/, "")
				.split(",")
				.slice(1)
				.map(this.cats.to_id.bind(this.cats));
			
			// Multi-hot encode the categories
			const pretensor = Array(this.cats.count).fill(0);
			for(const cat_id of next_cats)
				pretensor[cat_id] = 1;
			
			yield {
				xs: imagetensor,
				ys: tf.tensor(pretensor)
			};
		}
	}
	
	/**
	 * Returns a Categories object instance that represents the list of
	 * categories.
	 * @return {Categories}
	 */
	categories() {
		return this.cats;
	}
	
	/**
	 * Wraps data_from_dir() for the training directory in tf.data.generator().
	 * @return {tf.data.Dataset}
	 */
	dataset_train() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
	}
	
	/**
	 * Wraps data_from_dir() for the validation directory in tf.data.generator().
	 * @return {tf.data.Dataset}
	 */
	dataset_validate() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_validate));
	}
}
export default DataPreprocessor;