"use strict"; import fs from 'fs'; import path from 'path'; import tf from '@tensorflow/tfjs-node'; import genres from './Genres.mjs'; class DataPreprocessor { constructor(dir_input) { if(!fs.existsSync(dir_input)) throw new Error(`Error: The input directory '${dir_input}' does not exist.`); this.dir_input_train = path.join(dir_input, "train"); this.dir_input_validate = path.join(dir_input, "validate"); if(!fs.existsSync(this.dir_input_train)) throw new Error(`Error: Failed to locate the directory containing the training data.`); if(!fs.existsSync(this.dir_input_validate)) throw new Error(`Error: Failed to locate the directory containing the validation data.`); this.genres_count = genres.length; } genre2id(genre_name) { return genres.indexOf(genre_name); } id2genre(id) { return genres[id]; } async *data_from_dir(dirpath) { if(!fs.existsSync(dirpath)) throw new Error(`[DataPreprocessor/dataset_dir] Error: The specified directory does not exist.`); let filenames = await fs.promises.readdir(dirpath); for(let filename of filenames) { let imagetensor = tf.node.decodeImage( await fs.promises.readFile(path.join(dirpath, filename)), 3 // channels ); if(imagetensor.shape[0] == 256 && imagetensor.shape[1] == 256 && imagetensor.shape[0] == 3) throw new Error(`Error: Loaded image has a shape of [${imagetensor.shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`); // Strip the file extension, then split into an array of genres, and finally remove the id from the beginning let next_genres = filename.replace(/\.[a-zA-Z]+$/, "") .split(",") .slice(1) .map(this.genre2id); let genres_pretensor = Array(this.genres_count).fill(0); for(let genre_id of next_genres) genres_pretensor[genre_id] = 1; let genres_tensor = tf.tensor(genres_pretensor); // console.log(`>>>>>>>>> output shapes: ${imagetensor.shape}, ${genres_tensor.shape}`); yield { xs: imagetensor, ys: genres_tensor }; } } dataset_train() { return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train)); } dataset_validate() { return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_validate)); } } export default DataPreprocessor;