"use strict";

import fs from 'fs';
import path from 'path';
import tf from '@tensorflow/tfjs-node';
import Categories from './Categories.mjs';

/** Expected [height, width, channels] of every decoded input image. */
const EXPECTED_IMAGE_SHAPE = [256, 256, 3];

/**
 * Loads labelled images from a dataset directory and exposes them as
 * TensorFlow.js datasets of { xs, ys } pairs.
 *
 * Directory layout checked by the constructor:
 *   <dir_input>/train/          training images
 *   <dir_input>/validate/       validation images
 *   <dir_input>/categories.txt  category list, handed to Categories
 *
 * Each image's labels are read from its filename: the extension is stripped,
 * the remainder is split on ",", and the first field (an id) is discarded,
 * e.g. `123,catA,catB.jpg` -> ["catA", "catB"].
 */
class DataPreprocessor {
	/**
	 * @param {string} dir_input - Root directory of the dataset.
	 * @throws {Error} If dir_input, or its train/ or validate/ subdirectory, does not exist.
	 */
	constructor(dir_input) {
		if(!fs.existsSync(dir_input))
			throw new Error(`Error: The input directory '${dir_input}' does not exist.`);

		this.dir_input_train = path.join(dir_input, "train");
		this.dir_input_validate = path.join(dir_input, "validate");
		this.file_input_cats = path.join(dir_input, "categories.txt");

		if(!fs.existsSync(this.dir_input_train))
			throw new Error(`Error: Failed to locate the directory containing the training data.`);
		if(!fs.existsSync(this.dir_input_validate))
			throw new Error(`Error: Failed to locate the directory containing the validation data.`);

		this.cats = new Categories(this.file_input_cats);
	}

	/**
	 * Yields one { xs, ys } sample per file in dirpath.
	 *
	 * xs is the image decoded with 3 channels; it must have shape [256, 256, 3].
	 * ys is a 1-D tensor of length this.cats.count holding 1 at every category
	 * id named in the filename and 0 elsewhere. Sub-directories are skipped.
	 *
	 * The caller (typically tf.data) owns the yielded tensors.
	 *
	 * @param {string} dirpath - Directory to read images from.
	 * @throws {Error} If dirpath does not exist or an image has an unexpected shape.
	 */
	async *data_from_dir(dirpath) {
		if(!fs.existsSync(dirpath))
			throw new Error(`[DataPreprocessor/dataset_dir] Error: The specified directory does not exist.`);

		const entries = await fs.promises.readdir(dirpath, { withFileTypes: true });
		for(const entry of entries) {
			// A directory cannot be decoded as an image (readFile would fail with EISDIR).
			if(entry.isDirectory()) continue;
			const filename = entry.name;

			// Parse the labels before decoding, so that a failing category lookup
			// cannot leak an already-allocated image tensor.
			// Strip the file extension, split into the list of categories, and drop the leading id.
			const cat_ids = filename.replace(/\.[a-zA-Z]+$/, "")
				.split(",")
				.slice(1)
				.map((cat_name) => this.cats.to_id(cat_name));

			const imagetensor = tf.node.decodeImage(
				await fs.promises.readFile(path.join(dirpath, filename)),
				3 // channels
			);

			const { shape } = imagetensor;
			const shape_ok = shape.length === EXPECTED_IMAGE_SHAPE.length
				&& shape.every((dim, i) => dim === EXPECTED_IMAGE_SHAPE[i]);
			if(!shape_ok) {
				// Free the native memory before bailing out.
				imagetensor.dispose();
				throw new Error(`Error: Loaded image has a shape of [${shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
			}

			// Multi-hot label vector. NOTE(review): ids that Categories.to_id returns
			// outside [0, count) are not validated here; confirm to_id's contract.
			const pretensor = Array(this.cats.count).fill(0);
			for(const cat_id of cat_ids)
				pretensor[cat_id] = 1;

			yield { xs: imagetensor, ys: tf.tensor(pretensor) };
		}
	}

	/**
	 * Returns a Categories object instance that represents the list of
	 * categories.
	 * @return {Categories}
	 */
	categories() {
		return this.cats;
	}

	/**
	 * @return {tf.data.Dataset} Dataset of { xs, ys } samples from the train/ directory.
	 */
	dataset_train() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
	}

	/**
	 * @return {tf.data.Dataset} Dataset of { xs, ys } samples from the validate/ directory.
	 */
	dataset_validate() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_validate));
	}
}

export default DataPreprocessor;