It predicts!
This was a whole lot easier to implement than the Temporal CNN I'm doing for my PhD..... Next task: tidy up the code, document it, and write the README ....that can wait until tomorrow though :P
This commit is contained in:
parent
151d218102
commit
ab496d45de
2 changed files with 54 additions and 14 deletions
|
@ -18,27 +18,26 @@ class FilmPredictor {
|
|||
}
|
||||
|
||||
async init(dirpath = null) {
|
||||
if(!fs.existsSync(this.settings.output))
|
||||
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
|
||||
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
|
||||
if(!fs.existsSync(this.dir_checkpoints))
|
||||
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
||||
|
||||
|
||||
if(dirpath !== null)
|
||||
await this.load_model(dirpath);
|
||||
else
|
||||
else {
|
||||
if(!fs.existsSync(this.settings.output))
|
||||
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
|
||||
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
|
||||
if(!fs.existsSync(this.dir_checkpoints))
|
||||
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
||||
|
||||
this.make_model();
|
||||
|
||||
this.model.summary();
|
||||
this.model.summary();
|
||||
}
|
||||
}
|
||||
|
||||
async load_model(dirpath) {
|
||||
if(!fs.existsSync(dirpath)) {
|
||||
if(!fs.existsSync(dirpath))
|
||||
throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
|
||||
}
|
||||
console.error(`>>> Loading model from '${dirpath}'`);
|
||||
this.model = await this.tf.loadLayersModel(`file://${dirpath}`);
|
||||
this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
|
||||
console.error(`>>> Model loading complete`);
|
||||
}
|
||||
|
||||
|
@ -119,8 +118,32 @@ class FilmPredictor {
|
|||
})
|
||||
}
|
||||
|
||||
async predict(image) {
|
||||
async predict(imagefilepath) {
|
||||
if(!fs.existsSync(imagefilepath))
|
||||
throw new Error(`Error: No file exists at '${imagefilepath}'`);
|
||||
|
||||
let image_data = await fs.promises.readFile(imagefilepath);
|
||||
|
||||
let imagetensor = tf.tidy(() => {
|
||||
let image = tf.node.decodeImage(
|
||||
image_data,
|
||||
3 // channels
|
||||
);
|
||||
return image.reshape([1, ...image.shape ]);
|
||||
})
|
||||
let result_array = await this.model.predict(imagetensor).array();
|
||||
return this.array2genres(result_array);
|
||||
}
|
||||
|
||||
array2genres(arr) {
|
||||
let result = [];
|
||||
for(let i = 0; i < arr.length; i++) {
|
||||
if(arr[i] < 0.5)
|
||||
continue;
|
||||
|
||||
result.push(genres[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,22 @@
|
|||
"use strict";
|
||||
|
||||
export default async function(settings) {
|
||||
import fs from 'fs';
|
||||
|
||||
import FilmPredictor from '../../lib/FilmPredictor.mjs';
|
||||
|
||||
export default async function(settings) {
|
||||
if(!fs.existsSync(settings.input)) {
|
||||
console.error(`Error: The input file '${settings.input}' doesn't exist (did you type it correctly?)`);
|
||||
process.exit(1);
|
||||
}
|
||||
if(!fs.existsSync(settings.ai_model)) {
|
||||
console.error(`Error: Failed to locate AI model directory at '${settings.ai_model}'.`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
let model = new FilmPredictor(settings);
|
||||
await model.init(settings.ai_model); // We're training a new model here
|
||||
|
||||
let result = await model.predict(settings.input);
|
||||
console.log(result.join("\n"));
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue