2020-12-15 18:22:50 +00:00
"use strict" ;
import tf from '@tensorflow/tfjs-node' ;
2020-12-15 19:32:03 +00:00
/**
 * Builds and compiles a small convolutional image classifier.
 *
 * Architecture:
 *   - conv2d_1: 5x5 kernel, 3 filters, stride 1, taking [image_size, image_size, 3]
 *     (channels-last) input.
 *   - conv2d_2, conv2d_3: only when image_size > 32; 5x5 kernel, 3 filters, stride 2,
 *     to shrink larger images before the dense layer.
 *   - reshape: flattens the conv stack output into a 1-D vector.
 *   - dense: one unit per category.
 *
 * @param {number} cats_count - Number of output categories (units in the final dense layer).
 * @param {number} image_size - Width and height of the square RGB input images.
 * @param {boolean} [use_crossentropy=false] - When true, use a softmax output with
 *   categorical cross-entropy loss; otherwise a sigmoid output with mean squared error.
 * @param {string} [activation="relu"] - Activation applied by every convolutional layer.
 * @returns {tf.Sequential} The compiled model (Adam optimizer).
 * @throws {RangeError} If cats_count is below 1, or image_size is smaller than the
 *   convolution kernel (either would produce an empty or negative layer shape).
 */
export default function (cats_count, image_size, use_crossentropy = false, activation = "relu") {
	const KERNEL_SIZE = 5;
	
	// Negated comparisons also reject NaN / undefined inputs.
	if (!(cats_count >= 1))
		throw new RangeError(`cats_count must be at least 1 (got ${cats_count})`);
	if (!(image_size >= KERNEL_SIZE))
		throw new RangeError(`image_size must be at least ${KERNEL_SIZE} (got ${image_size})`);
	
	// Every conv layer shares layout, kernel, filter count and activation;
	// callers supply only what differs (name, strides, optional inputShape).
	const conv2d = (options) => tf.layers.conv2d({
		dataFormat: "channelsLast",
		kernelSize: KERNEL_SIZE,
		filters: 3,
		activation,
		...options,
	});
	
	const model = tf.sequential();
	model.add(conv2d({
		name: "conv2d_1",
		inputShape: [image_size, image_size, 3],
		strides: 1,
	}));
	if (image_size > 32) {
		// Larger images get two extra strided convolutions to reduce the
		// spatial size before the dense layer.
		model.add(conv2d({ name: "conv2d_2", strides: 2 }));
		model.add(conv2d({ name: "conv2d_3", strides: 2 }));
	}
	
	// Flatten: multiply every dimension of the conv stack output except the
	// leading batch dimension (index 0).
	const cnnStackOutputShape = model.layers[model.layers.length - 1].outputShape;
	const flatSize = cnnStackOutputShape
		.slice(1)
		.reduce((product, dim) => product * dim, 1);
	model.add(tf.layers.reshape({
		name: "reshape",
		targetShape: [flatSize],
	}));
	
	// softmax: categories are mutually exclusive (one label per image).
	// sigmoid: each category is scored independently (multi-label).
	model.add(tf.layers.dense({
		name: "dense",
		units: cats_count,
		activation: use_crossentropy ? "softmax" : "sigmoid",
	}));
	
	// TensorFlow.js has no root-mean-squared-error loss, so MSE is used here
	// and the square root is taken later when graphing the results with jq.
	const loss = use_crossentropy ? "categoricalCrossentropy" : "meanSquaredError";
	const metrics = use_crossentropy
		? ["accuracy", "categoricalCrossentropy", "categoricalAccuracy"]
		: ["mse" /* meanSquaredError */];
	if (use_crossentropy)
		console.log(`>>> Using categorical cross-entropy loss`);
	
	model.compile({
		optimizer: tf.train.adam(),
		loss,
		metrics,
	});
	return model;
}