68 lines
1.7 KiB
JavaScript
68 lines
1.7 KiB
JavaScript
"use strict";
|
|
|
|
import tf from '@tensorflow/tfjs-node';
|
|
|
|
/**
 * Builds and compiles a small sequential convolutional classifier.
 *
 * The network consists of one 5x5 convolution (plus two further 5x5
 * convolutions with stride 2 when `image_size` is greater than 32), a
 * reshape that flattens the convolutional output into a single vector, and
 * one dense output layer with `cats_count` units.
 *
 * @param	{number}	cats_count					Number of units in the dense output layer.
 * @param	{number}	image_size					Width and height of the square, 3-channel input images.
 * @param	{boolean}	[use_crossentropy=false]	If true, the output layer uses softmax and the model is compiled with categorical cross-entropy loss; otherwise it uses sigmoid with mean squared error loss.
 * @param	{string}	[activation="relu"]			Activation function name passed to every convolutional layer.
 * @return	{Object}	The compiled model, as created by tf.sequential().
 */
export default function create_model(cats_count, image_size, use_crossentropy = false, activation = "relu") {
	// Status messages are written to stderr so that they never mix with
	// anything the program emits on stdout.
	console.error(`>>> Creating new model (activation = ${activation})`);
	
	// All convolutional layers share the same kernel size, filter count, data
	// format and activation; only the name, the stride and (for the first
	// layer) the input shape differ.
	const make_conv = (name, strides, extra = {}) => tf.layers.conv2d({
		name,
		dataFormat: "channelsLast",
		kernelSize: 5,
		filters: 3,
		strides,
		activation,
		...extra
	});
	
	const model = tf.sequential();
	model.add(make_conv("conv2d_1", 1, {
		inputShape: [ image_size, image_size, 3 ]
	}));
	
	// The extra strided convolutions are only added for larger inputs
	if(image_size > 32) {
		model.add(make_conv("conv2d_2", 2));
		model.add(make_conv("conv2d_3", 2));
	}
	
	// Reshape and flatten: index 0 of the output shape is skipped, and the
	// remaining three dimensions are multiplied into one vector length.
	const [ , out_height, out_width, out_channels ] = model.layers[model.layers.length - 1].outputShape;
	model.add(tf.layers.reshape({
		name: "reshape",
		targetShape: [ out_height * out_width * out_channels ]
	}));
	model.add(tf.layers.dense({
		name: "dense",
		units: cats_count,
		activation: use_crossentropy ? "softmax" : "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead
	}));
	
	// We want the root mean squared error, but Tensorflow.js doesn't have an
	// option for that so we'll do it in post when graphing with jq
	let loss = "meanSquaredError";
	let metrics = [ "mse" /* meanSquaredError */ ];
	
	if(use_crossentropy) {
		// Sent to stderr like the message above (this one previously went to stdout)
		console.error(`>>> Using categorical cross-entropy loss`);
		loss = "categoricalCrossentropy";
		metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ];
	}
	
	model.compile({
		optimizer: tf.train.adam(),
		loss,
		metrics
	});
	
	return model;
}
|