film-poster-genres/src/lib/model.mjs

68 lines
1.7 KiB
JavaScript

"use strict";
import tf from '@tensorflow/tfjs-node';
/**
 * Builds and compiles a small convolutional classifier.
 *
 * Architecture, in order:
 *  - "conv2d_1": 5x5 kernel, 3 filters, stride 1, ReLU, taking
 *    [image_size, image_size, 3] channels-last input.
 *  - Only when image_size > 32: "conv2d_2" and "conv2d_3", same settings but
 *    with stride 2.
 *  - "reshape": flattens the last convolution's output to a single dimension.
 *  - "dense": cats_count units, softmax (cross-entropy mode) or sigmoid.
 *
 * The model is compiled with the Adam optimiser. Progress messages are
 * written to stderr.
 *
 * @param {number} cats_count Number of units in the final dense layer.
 * @param {number} image_size Width and height of the input images.
 * @param {boolean} [use_crossentropy=false] When truthy, use a softmax output
 *   with categorical cross-entropy loss; otherwise a sigmoid output with mean
 *   squared error loss.
 * @returns {tf.Sequential} The compiled model.
 */
export default function(cats_count, image_size, use_crossentropy = false) {
	console.error(`>>> Creating new model`);

	const model = tf.sequential();

	// All convolutions share everything except name, stride and (for the
	// first layer only) the explicit input shape.
	const add_conv = (name, strides, extra = {}) => {
		model.add(tf.layers.conv2d({
			name,
			dataFormat: "channelsLast",
			kernelSize: 5,
			filters: 3,
			strides,
			activation: "relu",
			...extra
		}));
	};

	add_conv("conv2d_1", 1, { inputShape: [ image_size, image_size, 3 ] });
	if(image_size > 32) {
		// Larger inputs get two extra strided convolutions before flattening.
		add_conv("conv2d_2", 2);
		add_conv("conv2d_3", 2);
	}

	// Flatten: outputShape is [ batch, height, width, channels ], so drop the
	// batch dimension and multiply the remainder together.
	const cnn_stack_output_shape = model.layers[model.layers.length - 1].outputShape;
	const flat_size = cnn_stack_output_shape
		.slice(1)
		.reduce((product, dim) => product * dim, 1);
	model.add(tf.layers.reshape({
		name: "reshape",
		targetShape: [ flat_size ]
	}));

	model.add(tf.layers.dense({
		name: "dense",
		units: cats_count,
		// Softmax when predicting a single label at a time; sigmoid otherwise.
		activation: use_crossentropy ? "softmax" : "sigmoid"
	}));

	// In the default mode we actually want the *root* mean squared error, but
	// TensorFlow.js has no built-in option for that, so the square root is
	// taken in post-processing when graphing with jq.
	const loss = use_crossentropy ? "categoricalCrossentropy" : "meanSquaredError";
	const metrics = use_crossentropy
		? [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ]
		: [ "mse" /* meanSquaredError */ ];

	if(use_crossentropy) {
		// Diagnostics go to stderr, matching the message at the top of this
		// function, so that stdout is not polluted by log lines.
		console.error(`>>> Using categorical cross-entropy loss`);
	}

	model.compile({
		optimizer: tf.train.adam(),
		loss,
		metrics
	});

	return model;
}