It predicts!

This was a whole lot easier to implement than the Temporal CNN I'm doing 
for my PhD.....

Next task: tidy up the code, document it, and write the README

....that can wait until tomorrow though :P
This commit is contained in:
Starbeamrainbowlabs 2020-09-15 19:09:42 +01:00
父節點 151d218102
當前提交 ab496d45de
簽署人: sbrl
GPG 金鑰 ID: 1BE5172E637709C2
共有 2 個檔案被更改,包括 54 行新增14 行删除

查看文件

@ -18,27 +18,26 @@ class FilmPredictor {
}
async init(dirpath = null) {
if(!fs.existsSync(this.settings.output))
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
if(dirpath !== null)
await this.load_model(dirpath);
else
else {
if(!fs.existsSync(this.settings.output))
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
this.make_model();
this.model.summary();
this.model.summary();
}
}
async load_model(dirpath) {
if(!fs.existsSync(dirpath)) {
if(!fs.existsSync(dirpath))
throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
}
console.error(`>>> Loading model from '${dirpath}'`);
this.model = await this.tf.loadLayersModel(`file://${dirpath}`);
this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
console.error(`>>> Model loading complete`);
}
@ -119,8 +118,32 @@ class FilmPredictor {
})
}
async predict(image) {
async predict(imagefilepath) {
if(!fs.existsSync(imagefilepath))
throw new Error(`Error: No file exists at '${imagefilepath}'`);
let image_data = await fs.promises.readFile(imagefilepath);
let imagetensor = tf.tidy(() => {
let image = tf.node.decodeImage(
image_data,
3 // channels
);
return image.reshape([1, ...image.shape ]);
})
let result_array = await this.model.predict(imagetensor).array();
return this.array2genres(result_array);
}
array2genres(arr) {
let result = [];
for(let i = 0; i < arr.length; i++) {
if(arr[i] < 0.5)
continue;
result.push(genres[i]);
}
return result;
}
}

查看文件

@ -1,5 +1,22 @@
"use strict";
import fs from 'fs';
import FilmPredictor from '../../lib/FilmPredictor.mjs';
export default async function(settings) {
if(!fs.existsSync(settings.input)) {
console.error(`Error: The input file '${settings.input}' doesn't exist (did you type it correctly?)`);
process.exit(1);
}
if(!fs.existsSync(settings.ai_model)) {
console.error(`Error: Failed to locate AI model directory at '${settings.ai_model}'.`);
process.exit(1);
}
let model = new FilmPredictor(settings);
await model.init(settings.ai_model); // We're training a new model here
let result = await model.predict(settings.input);
console.log(result.join("\n"));
}