It predicts!
This was a whole lot easier to implement than the Temporal CNN I'm doing for my PhD..... Next task: tidy up the code, document it, and write the README ....that can wait until tomorrow though :P
This commit is contained in:
parent
151d218102
commit
ab496d45de
2 changed files with 54 additions and 14 deletions
|
@ -18,27 +18,26 @@ class FilmPredictor {
|
||||||
}
|
}
|
||||||
|
|
||||||
async init(dirpath = null) {
|
async init(dirpath = null) {
|
||||||
if(!fs.existsSync(this.settings.output))
|
|
||||||
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
|
|
||||||
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
|
|
||||||
if(!fs.existsSync(this.dir_checkpoints))
|
|
||||||
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
|
||||||
|
|
||||||
|
|
||||||
if(dirpath !== null)
|
if(dirpath !== null)
|
||||||
await this.load_model(dirpath);
|
await this.load_model(dirpath);
|
||||||
else
|
else {
|
||||||
|
if(!fs.existsSync(this.settings.output))
|
||||||
|
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
|
||||||
|
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
|
||||||
|
if(!fs.existsSync(this.dir_checkpoints))
|
||||||
|
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
|
||||||
|
|
||||||
this.make_model();
|
this.make_model();
|
||||||
|
|
||||||
this.model.summary();
|
this.model.summary();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async load_model(dirpath) {
|
async load_model(dirpath) {
|
||||||
if(!fs.existsSync(dirpath)) {
|
if(!fs.existsSync(dirpath))
|
||||||
throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
|
throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
|
||||||
}
|
|
||||||
console.error(`>>> Loading model from '${dirpath}'`);
|
console.error(`>>> Loading model from '${dirpath}'`);
|
||||||
this.model = await this.tf.loadLayersModel(`file://${dirpath}`);
|
this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
|
||||||
console.error(`>>> Model loading complete`);
|
console.error(`>>> Model loading complete`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,8 +118,32 @@ class FilmPredictor {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async predict(image) {
|
async predict(imagefilepath) {
|
||||||
|
if(!fs.existsSync(imagefilepath))
|
||||||
|
throw new Error(`Error: No file exists at '${imagefilepath}'`);
|
||||||
|
|
||||||
|
let image_data = await fs.promises.readFile(imagefilepath);
|
||||||
|
|
||||||
|
let imagetensor = tf.tidy(() => {
|
||||||
|
let image = tf.node.decodeImage(
|
||||||
|
image_data,
|
||||||
|
3 // channels
|
||||||
|
);
|
||||||
|
return image.reshape([1, ...image.shape ]);
|
||||||
|
})
|
||||||
|
let result_array = await this.model.predict(imagetensor).array();
|
||||||
|
return this.array2genres(result_array);
|
||||||
|
}
|
||||||
|
|
||||||
|
array2genres(arr) {
|
||||||
|
let result = [];
|
||||||
|
for(let i = 0; i < arr.length; i++) {
|
||||||
|
if(arr[i] < 0.5)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
result.push(genres[i]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,22 @@
|
||||||
"use strict";
|
"use strict";
|
||||||
|
|
||||||
|
import fs from 'fs';
|
||||||
|
|
||||||
|
import FilmPredictor from '../../lib/FilmPredictor.mjs';
|
||||||
|
|
||||||
export default async function(settings) {
|
export default async function(settings) {
|
||||||
|
if(!fs.existsSync(settings.input)) {
|
||||||
|
console.error(`Error: The input file '${settings.input}' doesn't exist (did you type it correctly?)`);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
if(!fs.existsSync(settings.ai_model)) {
|
||||||
|
console.error(`Error: Failed to locate AI model directory at '${settings.ai_model}'.`);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let model = new FilmPredictor(settings);
|
||||||
|
await model.init(settings.ai_model); // We're training a new model here
|
||||||
|
|
||||||
|
let result = await model.predict(settings.input);
|
||||||
|
console.log(result.join("\n"));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue