87 lines
2.6 KiB
JavaScript
87 lines
2.6 KiB
JavaScript
"use strict";
|
|
|
|
import fs from 'fs';
|
|
import path from 'path';
|
|
|
|
import tf from '@tensorflow/tfjs-node';
|
|
|
|
import Categories from './Categories.mjs';
|
|
|
|
/**
 * Loads (image, multi-hot label) pairs from a dataset directory and exposes
 * them as tfjs datasets.
 *
 * The dataset root must contain a `train/` and a `validate/` sub-directory;
 * `categories.txt` in the root is handed to the `Categories` constructor.
 */
class DataPreprocessor {
	/**
	 * @param {string} dir_input Root directory of the dataset.
	 * @throws {Error} If `dir_input`, `dir_input/train` or
	 *     `dir_input/validate` does not exist.
	 */
	constructor(dir_input) {
		if(!fs.existsSync(dir_input)) {
			throw new Error(`Error: The input directory '${dir_input}' does not exist.`);
		}
		
		this.dir_input_train = path.join(dir_input, "train");
		this.dir_input_validate = path.join(dir_input, "validate");
		this.file_input_cats = path.join(dir_input, "categories.txt");
		
		if(!fs.existsSync(this.dir_input_train)) {
			throw new Error(`Error: Failed to locate the directory containing the training data.`);
		}
		if(!fs.existsSync(this.dir_input_validate)) {
			throw new Error(`Error: Failed to locate the directory containing the validation data.`);
		}
		
		this.cats = new Categories(this.file_input_cats);
	}
	
	/**
	 * Async generator yielding one `{ xs, ys }` example per file in `dirpath`.
	 *
	 * Each file is decoded with 3 channels into `xs`. The label comes from the
	 * filename: a trailing alphabetic extension is stripped, the rest is split
	 * on commas, the first token (the id) is dropped and every remaining token
	 * is mapped through `this.cats.to_id`. `ys` is a 0/1 tensor of length
	 * `this.cats.count` with a 1 at each resulting id.
	 * Sub-directories inside `dirpath` are skipped.
	 *
	 * @param {string} dirpath Directory to read examples from.
	 * @throws {Error} If `dirpath` does not exist, or if a decoded image does
	 *     not have shape [256, 256, 3].
	 */
	async *data_from_dir(dirpath) {
		if(!fs.existsSync(dirpath)) {
			throw new Error(`[DataPreprocessor/data_from_dir] Error: The specified directory '${dirpath}' does not exist.`);
		}
		
		const entries = await fs.promises.readdir(dirpath, { withFileTypes: true });
		for(const entry of entries) {
			// A directory can't be read as an image file (readFile would throw EISDIR).
			if(entry.isDirectory()) continue;
			const filename = entry.name;
			
			const imagetensor = tf.node.decodeImage(
				await fs.promises.readFile(path.join(dirpath, filename)),
				3 // channels
			);
			
			// Reject anything that isn't exactly 256x256 with 3 channels
			// (a rank-4 result, e.g. from an animated image, is rejected too).
			const [ height, width, channels ] = imagetensor.shape;
			if(imagetensor.shape.length !== 3
				|| height !== 256
				|| width !== 256
				|| channels !== 3) {
				const shape_text = imagetensor.shape.join(", ");
				imagetensor.dispose(); // don't leak the tensor we are rejecting
				throw new Error(`Error: Loaded image has a shape of [${shape_text}], but a shape of [ 256, 256, 3 ] was expected.`);
			}
			
			// Strip the file extension, then split into an array of genres, and
			// finally remove the id from the beginning.
			// NOTE(review): how names unknown to Categories are handled depends
			// on to_id — confirm it throws rather than returning a bogus index.
			const cat_ids = filename.replace(/\.[a-zA-Z]+$/, "")
				.split(",")
				.slice(1)
				.map((cat_name) => this.cats.to_id(cat_name));
			
			const multi_hot = Array.from({ length: this.cats.count }, () => 0);
			for(const cat_id of cat_ids) {
				multi_hot[cat_id] = 1;
			}
			
			yield {
				xs: imagetensor,
				ys: tf.tensor(multi_hot),
			};
		}
	}
	
	/**
	 * Returns a Categories object instance that represents the list of
	 * categories.
	 * @return {Categories}
	 */
	categories() {
		return this.cats;
	}
	
	/**
	 * @return A tf.data dataset over the examples in the `train/` directory.
	 */
	dataset_train() {
		return tf.data.generator(() => this.data_from_dir(this.dir_input_train));
	}
	
	/**
	 * @return A tf.data dataset over the examples in the `validate/` directory.
	 */
	dataset_validate() {
		return tf.data.generator(() => this.data_from_dir(this.dir_input_validate));
	}
}

export default DataPreprocessor;
|