2020-09-15 17:32:51 +00:00
"use strict" ;
import path from 'path' ;
import fs from 'fs' ;
import tf from '@tensorflow/tfjs-node' ;
2020-09-23 16:23:51 +00:00
// import tf from '@tensorflow/tfjs-node-gpu';
// import tf from '@tensorflow/tfjs';
2020-09-15 17:32:51 +00:00
import genres from './Genres.mjs' ;
/**
 * Multi-label genre classifier built on a small TensorFlow.js CNN.
 *
 * The network takes one 256x256 image with 3 channels and outputs one
 * sigmoid score per entry of the imported `genres` list; `array2genres`
 * turns those scores back into genre entries.
 *
 * Lifecycle:
 *   - `init()` creates the output directories and builds a new model;
 *     `init(dirpath)` loads a previously saved model instead.
 *   - `train()` fits the model, writing a checkpoint and a metrics line
 *     after every epoch.
 *   - `predict()` classifies a single image file.
 */
class FilmPredictor {
    /**
     * @param {object} settings - Run configuration. This class only reads
     *   `settings.output`, the directory for checkpoints and metrics.
     */
    constructor(settings) {
        this.settings = settings;
        // Number of output units (one per genre), not the genre list itself.
        this.genres = genres.length;
        this.batch_size = 32;
        this.prefetch = 4;
    }

    /**
     * Prepares the predictor. When `dirpath` is given the saved model in it is
     * loaded; otherwise the output directories are created and a fresh model
     * is built and summarised.
     * @param {string|null} [dirpath=null] - Directory containing `model.json`.
     * @returns {Promise<void>}
     * @throws {Error} If `dirpath` is given but does not exist.
     */
    async init(dirpath = null) {
        if (dirpath !== null) {
            await this.load_model(dirpath);
            return;
        }
        await this.prepare_output_dirs();
        this.make_model();
        this.model.summary();
    }

    /**
     * Ensures `settings.output` and its `checkpoints` subdirectory exist and
     * stores the checkpoint directory path in `this.dir_checkpoints`.
     * Safe to call repeatedly: `mkdir` with `recursive: true` creates missing
     * parents and does not fail if the directory already exists.
     * @returns {Promise<void>}
     */
    async prepare_output_dirs() {
        this.dir_checkpoints = path.join(this.settings.output, "checkpoints");
        await fs.promises.mkdir(this.dir_checkpoints, { recursive: true, mode: 0o755 });
    }

    /**
     * Loads a Layers model from `<dirpath>/model.json` into `this.model`.
     * @param {string} dirpath - Directory previously written by `model.save`.
     * @returns {Promise<void>}
     * @throws {Error} If `dirpath` does not exist.
     */
    async load_model(dirpath) {
        if (!fs.existsSync(dirpath))
            throw new Error(`Error: The directory ${dirpath} doesn't exist.`);
        console.error(`>>> Loading model from '${dirpath}'`);
        this.model = await tf.loadLayersModel(`file://${path.resolve(dirpath, "model.json")}`);
        console.error(`>>> Model loading complete`);
    }

    /**
     * Builds and compiles a fresh model in `this.model`:
     * three 5x5, stride-2, 3-filter ReLU conv layers, a flattening `reshape`,
     * and a dense sigmoid layer with one unit per genre.
     */
    make_model() {
        console.error(`>>> Creating new model`);
        this.model = tf.sequential();

        const conv_layer_count = 3;
        for (let i = 1; i <= conv_layer_count; i++) {
            const conv_config = {
                name: `conv2d_${i}`,
                dataFormat: "channelsLast",
                kernelSize: 5,
                filters: 3,
                strides: 2,
                activation: "relu"
            };
            // Only the first layer declares the input shape; later layers infer theirs.
            if (i === 1)
                conv_config.inputShape = [256, 256, 3];
            this.model.add(tf.layers.conv2d(conv_config));
        }

        // Reshape and flatten: [batch, height, width, channels] -> [batch, height * width * channels]
        const [, height, width, channels] = this.model.getLayer(`conv2d_${conv_layer_count}`).outputShape;
        this.model.add(tf.layers.reshape({
            name: "reshape",
            targetShape: [height * width * channels]
        }));
        this.model.add(tf.layers.dense({
            name: "dense",
            units: this.genres,
            activation: "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
        }));
        this.model.compile({
            optimizer: tf.train.adam(),
            loss: "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq
            metrics: ["mse" /* meanSquaredError */]
        });
    }

    /**
     * Fits the model. After each epoch it saves a checkpoint to
     * `<settings.output>/checkpoints/<epoch>` and appends that epoch's metrics
     * as one JSON line to `<settings.output>/metrics.stream.json`.
     * The output directories are created first, so this also works on a model
     * obtained through `init(dirpath)`.
     * @param {object} dataset_train - tf.data Dataset of individual examples; batched here with `batch_size`.
     * @param {object} dataset_validate - tf.data Dataset of individual examples; batched here with `batch_size`.
     * @param {number} [epochs=50] - Number of epochs to train for.
     * @returns {Promise<void>}
     */
    async train(dataset_train, dataset_validate, epochs = 50) {
        // init(dirpath) never sets up dir_checkpoints, so guarantee it before the first checkpoint.
        await this.prepare_output_dirs();
        const batched_train = dataset_train.batch(this.batch_size).prefetch(this.prefetch);
        const batched_validate = dataset_validate.batch(this.batch_size).prefetch(this.prefetch);
        await this.model.fitDataset(batched_train, {
            epochs,
            verbose: 1,
            validationData: batched_validate,
            yieldEvery: "batch",
            shuffle: false,
            callbacks: {
                onEpochEnd: async (epoch, metrics) => {
                    console.error(`>>> Epoch ${epoch} complete, metrics:`, metrics);
                    const dir_output_checkpoint = `file://${path.join(this.dir_checkpoints, String(epoch))}`;
                    await Promise.all([
                        this.model.save(dir_output_checkpoint),
                        fs.promises.appendFile(
                            path.join(this.settings.output, "metrics.stream.json"),
                            JSON.stringify(metrics) + "\n"
                        )
                    ]);
                }
            }
        });
    }

    /**
     * Classifies a single image file.
     * NOTE(review): no resizing is done, so the decoded image presumably must
     * already match the model's 256x256x3 input shape — confirm with callers.
     * @param {string} imagefilepath - Path to an image `tf.node.decodeImage` can decode.
     * @returns {Promise<Array>} Entries of `genres` whose score is at least 0.5.
     * @throws {Error} If no file exists at `imagefilepath`.
     */
    async predict(imagefilepath) {
        if (!fs.existsSync(imagefilepath))
            throw new Error(`Error: No file exists at '${imagefilepath}'`);

        const image_data = await fs.promises.readFile(imagefilepath);
        // tidy() frees the intermediate decoded tensor; only the returned tensor escapes it.
        const imagetensor = tf.tidy(() => {
            const image = tf.node.decodeImage(image_data, 3 /* channels */);
            // Prepend a batch dimension of 1.
            return image.reshape([1, ...image.shape]);
        });

        // Tensors created outside tidy() are not released automatically, so
        // dispose the input and the prediction even if predict()/array() throws.
        let prediction;
        try {
            prediction = this.model.predict(imagetensor);
            const [result_array] = await prediction.array();
            return this.array2genres(result_array);
        } finally {
            imagetensor.dispose();
            if (prediction !== undefined)
                prediction.dispose();
        }
    }

    /**
     * Maps per-genre scores to the genres whose score reaches `threshold`.
     * NaN scores never pass the threshold.
     * @param {number[]} arr - Scores, indexed like the imported `genres` list.
     * @param {number} [threshold=0.5] - Minimum score (inclusive) for a genre to be included.
     * @returns {Array} Matching entries of `genres`, in index order.
     */
    array2genres(arr, threshold = 0.5) {
        const result = [];
        arr.forEach((score, i) => {
            if (score >= threshold)
                result.push(genres[i]);
        });
        return result;
    }
}
export default FilmPredictor ;