film-poster-genres/src/lib/DataPreprocessor.mjs

81 lines
2.3 KiB
JavaScript
Raw Normal View History

2020-09-15 17:32:51 +00:00
"use strict";
import fs from 'fs';
import path from 'path';
import tf from '@tensorflow/tfjs-node';
import genres from './Genres.mjs';
/**
 * Loads images and their genre labels from a dataset directory and exposes
 * them as TensorFlow.js datasets.
 *
 * The input directory must contain "train" and "validate" subdirectories.
 * Labels are read from each filename: the extension is stripped, the name is
 * split on commas, the first element (an id) is dropped, and the remaining
 * elements are looked up in the imported genres list.
 */
class DataPreprocessor {
	/**
	 * @param {string} dir_input Directory containing the "train" and "validate" subdirectories.
	 * @throws {Error} If the input directory or either subdirectory does not exist.
	 */
	constructor(dir_input) {
		if(!fs.existsSync(dir_input))
			throw new Error(`Error: The input directory '${dir_input}' does not exist.`);

		this.dir_input_train = path.join(dir_input, "train");
		this.dir_input_validate = path.join(dir_input, "validate");

		if(!fs.existsSync(this.dir_input_train))
			throw new Error(`Error: Failed to locate the directory containing the training data.`);
		if(!fs.existsSync(this.dir_input_validate))
			throw new Error(`Error: Failed to locate the directory containing the validation data.`);

		// Length of the multi-hot label vector produced for every image
		this.genres_count = genres.length;
	}

	/**
	 * Looks up the index of a genre name in the genres list.
	 * @param {string} genre_name The genre name to look up.
	 * @returns {number} The index of the genre, or -1 if it is not in the list.
	 */
	genre2id(genre_name) {
		return genres.indexOf(genre_name);
	}

	/**
	 * Looks up the genre stored at the given index in the genres list.
	 * @param {number} id The index to look up.
	 * @returns {*} The entry at that index, or undefined if the index is out of range.
	 */
	id2genre(id) {
		return genres[id];
	}

	/**
	 * Asynchronously yields one { xs, ys } sample per file in a directory.
	 * xs is the decoded 3-channel image tensor; ys is a multi-hot tensor of
	 * length genres_count with a 1 at the index of every genre named in the
	 * filename.
	 * @param {string} dirpath The directory to read files from.
	 * @throws {Error} If the directory does not exist, or a decoded image is not of shape [ 256, 256, 3 ].
	 */
	async *data_from_dir(dirpath) {
		if(!fs.existsSync(dirpath))
			throw new Error(`[DataPreprocessor/dataset_dir] Error: The specified directory does not exist.`);

		const filenames = await fs.promises.readdir(dirpath);

		for(const filename of filenames) {
			const imagetensor = tf.node.decodeImage(
				await fs.promises.readFile(path.join(dirpath, filename)),
				3 // channels
			);

			// Reject anything that is not exactly [ height=256, width=256, channels=3 ]
			const shape = imagetensor.shape;
			if(shape.length !== 3
				|| shape[0] !== 256
				|| shape[1] !== 256
				|| shape[2] !== 3) {
				// Release the tensor before bailing out so that it is not leaked
				imagetensor.dispose();
				throw new Error(`Error: Loaded image has a shape of [${shape.join(", ")}], but a shape of [ 256, 256, 3 ] was expected.`);
			}

			// Strip the file extension, then split into an array of genres, and finally remove the id from the beginning
			const next_genres = filename.replace(/\.[a-zA-Z]+$/, "")
				.split(",")
				.slice(1)
				.map((genre_name) => this.genre2id(genre_name));

			const genres_pretensor = Array(this.genres_count).fill(0);
			for(const genre_id of next_genres) {
				// Names missing from the genres list map to -1 and contribute nothing to the label vector
				if(genre_id < 0) continue;
				genres_pretensor[genre_id] = 1;
			}

			yield {
				xs: imagetensor,
				ys: tf.tensor(genres_pretensor)
			};
		}
	}

	/**
	 * Creates a dataset that reads from the training directory.
	 * @returns {*} The value returned by tf.data.generator.
	 */
	dataset_train() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_train));
	}

	/**
	 * Creates a dataset that reads from the validation directory.
	 * @returns {*} The value returned by tf.data.generator.
	 */
	dataset_validate() {
		return tf.data.generator(this.data_from_dir.bind(this, this.dir_input_validate));
	}
}
export default DataPreprocessor;