Browse Source

It predicts!

This was a whole lot easier to implement than the Temporal CNN I'm doing 
for my PhD.....

Next task: tidy up the code, document it, and write the README

....that can wait until tomorrow though :P
master
Starbeamrainbowlabs 4 months ago
parent
commit
ab496d45de
Signed by: sbrl GPG Key ID: 1BE5172E637709C2
2 changed files with 54 additions and 14 deletions
  1. +37
    -14
      src/lib/FilmPredictor.mjs
  2. +17
    -0
      src/subcommands/predict/predict.mjs

+ 37
- 14
src/lib/FilmPredictor.mjs View File

@@ -18,27 +18,26 @@ class FilmPredictor {
}
async init(dirpath = null) {
if(!fs.existsSync(this.settings.output))
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
if(dirpath !== null)
await this.load_model(dirpath);
else
else {
if(!fs.existsSync(this.settings.output))
await fs.promises.mkdir(this.settings.output, { recursive: true, mode: 0o755 });
this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
if(!fs.existsSync(this.dir_checkpoints))
await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
this.make_model();
this.model.summary();
this.model.summary();
}
}
async load_model(dirpath) {
if(!fs.existsSync(dirpath)) {
if(!fs.existsSync(dirpath))
throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
}
console.error(`>>> Loading model from '${dirpath}'`);
this.model = await this.tf.loadLayersModel(`file://${dirpath}`);
this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
console.error(`>>> Model loading complete`);
}
@@ -119,8 +118,32 @@ class FilmPredictor {
})
}
async predict(image) {
async predict(imagefilepath) {
if(!fs.existsSync(imagefilepath))
throw new Error(`Error: No file exists at '${imagefilepath}'`);
let image_data = await fs.promises.readFile(imagefilepath);
let imagetensor = tf.tidy(() => {
let image = tf.node.decodeImage(
image_data,
3 // channels
);
return image.reshape([1, ...image.shape ]);
})
let result_array = await this.model.predict(imagetensor).array();
return this.array2genres(result_array);
}
array2genres(arr) {
let result = [];
for(let i = 0; i < arr.length; i++) {
if(arr[i] < 0.5)
continue;
result.push(genres[i]);
}
return result;
}
}



+ 17
- 0
src/subcommands/predict/predict.mjs View File

@@ -1,5 +1,22 @@
"use strict";

import fs from 'fs';

import FilmPredictor from '../../lib/FilmPredictor.mjs';

export default async function(settings) {
if(!fs.existsSync(settings.input)) {
console.error(`Error: The input file '${settings.input}' doesn't exist (did you type it correctly?)`);
process.exit(1);
}
if(!fs.existsSync(settings.ai_model)) {
console.error(`Error: Failed to locate AI model directory at '${settings.ai_model}'.`);
process.exit(1);
}
let model = new FilmPredictor(settings);
await model.init(settings.ai_model); // We're training a new model here
let result = await model.predict(settings.input);
console.log(result.join("\n"));
}

Loading…
Cancel
Save