// 81 lines, 2.3 KiB, JavaScript
"use strict";
|
||
|
|
||
|
import fs from 'fs';
|
||
|
import path from 'path';
|
||
|
|
||
|
import tf from '@tensorflow/tfjs-node';
|
||
|
|
||
|
import genres from './Genres.mjs';
|
||
|
|
||
|
/**
 * Loads labelled images from a dataset directory and exposes them as
 * TensorFlow.js datasets.
 *
 * The input directory must contain a `train` and a `validate` subdirectory.
 * Each file inside them is expected to be named `<id>,<genre>,<genre>,...<.ext>`.
 * The genres after the id become a multi-hot label vector, indexed by position
 * in the `genres` list imported from Genres.mjs.
 */
class DataPreprocessor {
	/**
	 * @param {string} dir_input - Root dataset directory containing `train/` and `validate/`.
	 * @throws {Error} If `dir_input`, or either of its `train`/`validate` subdirectories, does not exist.
	 */
	constructor(dir_input) {
		if (!fs.existsSync(dir_input))
			throw new Error(`Error: The input directory '${dir_input}' does not exist.`);

		this.dir_input_train = path.join(dir_input, "train");
		this.dir_input_validate = path.join(dir_input, "validate");

		if (!fs.existsSync(this.dir_input_train))
			throw new Error(`Error: Failed to locate the directory containing the training data.`);
		if (!fs.existsSync(this.dir_input_validate))
			throw new Error(`Error: Failed to locate the directory containing the validation data.`);

		this.genres_count = genres.length;
	}

	/**
	 * Maps a genre name to its numeric id.
	 * @param {string} genre_name
	 * @returns {number} The index of `genre_name` in the genres list, or -1 if it is not listed.
	 */
	genre2id(genre_name) {
		return genres.indexOf(genre_name);
	}

	/**
	 * Maps a numeric genre id back to its name.
	 * @param {number} id
	 * @returns {string} The entry at index `id` of the genres list
	 *   (presumably `undefined` when out of range, assuming the list is a plain array).
	 */
	id2genre(id) {
		return genres[id];
	}

	/**
	 * Builds the multi-hot label array for a filename of the form
	 * `<id>,<genre>,<genre>,...<.ext>`.
	 * @param {string} filename - Bare filename (no directory part).
	 * @returns {number[]} An array of length `genres_count` holding 1 at the id of
	 *   every genre named in the filename and 0 everywhere else.
	 * @throws {Error} If the filename names a genre that is not in the genres list.
	 */
	labels_from_filename(filename) {
		// Strip the file extension, split into comma-separated fields, and drop the leading id.
		const genre_names = filename.replace(/\.[a-zA-Z]+$/, "")
			.split(",")
			.slice(1);

		const labels = new Array(this.genres_count).fill(0);
		for (const genre_name of genre_names) {
			const genre_id = this.genre2id(genre_name);
			// genre2id() returns -1 for an unlisted genre. Writing labels[-1] would
			// silently lose the label, so fail loudly instead.
			if (genre_id === -1)
				throw new Error(`Error: Unknown genre '${genre_name}' in filename '${filename}'.`);
			labels[genre_id] = 1;
		}
		return labels;
	}

	/**
	 * Async generator that yields one `{ xs, ys }` sample per entry in `dirpath`.
	 *
	 * `xs` is the image decoded with 3 channels and checked to have shape
	 * [256, 256, 3]. `ys` is a tensor of the multi-hot labels produced by
	 * `labels_from_filename`.
	 *
	 * @param {string} dirpath - Directory whose entries are all labelled image files.
	 * @yields {{xs: tf.Tensor, ys: tf.Tensor}}
	 * @throws {Error} If the directory does not exist, an image does not have
	 *   shape [256, 256, 3], or a filename names an unknown genre.
	 */
	async *data_from_dir(dirpath) {
		if (!fs.existsSync(dirpath))
			throw new Error(`[DataPreprocessor/dataset_dir] Error: The specified directory does not exist.`);

		const filenames = await fs.promises.readdir(dirpath);
		for (const filename of filenames) {
			// Parse the labels first so a bad filename fails before any tensor is allocated.
			const labels = this.labels_from_filename(filename);

			const imagetensor = tf.node.decodeImage(
				await fs.promises.readFile(path.join(dirpath, filename)),
				3 // channels
			);

			const shape = imagetensor.shape;
			const is_expected_shape = shape.length === 3
				&& shape[0] === 256
				&& shape[1] === 256
				&& shape[2] === 3;
			if (!is_expected_shape) {
				const message = `Error: Loaded image has a shape of [${shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`;
				// Release the decoded tensor's memory before bailing out.
				imagetensor.dispose();
				throw new Error(message);
			}

			yield {
				xs: imagetensor,
				ys: tf.tensor(labels),
			};
		}
	}

	/**
	 * @returns {tf.data.Dataset} A dataset whose elements come from
	 *   `data_from_dir` over the `train` subdirectory.
	 */
	dataset_train() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
	}

	/**
	 * @returns {tf.data.Dataset} A dataset whose elements come from
	 *   `data_from_dir` over the `validate` subdirectory.
	 */
	dataset_validate() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_validate));
	}
}
|
||
|
|
||
|
export default DataPreprocessor;
|