"use strict"; import tf from '@tensorflow/tfjs-node'; export default function(cats_count, image_size, use_crossentropy = false) { console.error(`>>> Creating new model`); let model = tf.sequential(); model.add(tf.layers.conv2d({ name: "conv2d_1", dataFormat: "channelsLast", inputShape: [ image_size, image_size, 3 ], kernelSize: 5, filters: 3, strides: 1, activation: "relu" })); if(image_size > 32) { model.add(tf.layers.conv2d({ name: "conv2d_2", dataFormat: "channelsLast", kernelSize: 5, filters: 3, strides: 2, activation: "relu" })); model.add(tf.layers.conv2d({ name: "conv2d_3", dataFormat: "channelsLast", kernelSize: 5, filters: 3, strides: 2, activation: "relu" })); } // Reshape and flatten let cnn_stack_output_shape = model.layers[model.layers.length - 1].outputShape; model.add(tf.layers.reshape({ name: "reshape", targetShape: [ cnn_stack_output_shape[1] * cnn_stack_output_shape[2] * cnn_stack_output_shape[3] ] })); model.add(tf.layers.dense({ name: "dense", units: cats_count, activation: use_crossentropy ? "softmax" : "sigmoid" // If you're only predicting a single label at a time, then choose "softmax" instead })); let loss = "meanSquaredError", // we want the root mean squared error, but Tensorflow.js doesn't have an option for that so we'll do it in post when graphing with jq metrics = [ "mse" /* meanSquaredError */ ]; if(use_crossentropy) { console.log(`>>> Using categorical cross-entropy loss`); loss = "categoricalCrossentropy"; metrics = [ "accuracy", "categoricalCrossentropy", "categoricalAccuracy" ]; } model.compile({ optimizer: tf.train.adam(), loss, metrics }); return model; }