Make activation function a setting on the CLI

This commit is contained in:
Starbeamrainbowlabs 2020-12-15 19:32:03 +00:00
parent 926a82bebf
commit 10138cd310
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
4 changed files with 9 additions and 6 deletions

View file

@ -16,6 +16,7 @@ export default async function () {
.argument("output", "Path to the output directory to save the trained AI to")
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
.argument("activation", "Set the activation function for the CNN layer(s) (default: relu)", "relu", "string");
cli.subcommand("predict", "Predicts the genres of the specified image")
.argument("input", "Path to the input image")
.argument("ai-model", "Path to the saved AI model to load");

View file

@ -17,6 +17,7 @@ class FilmPredictor {
this.image_size = 256;
this.use_crossentropy = false;
this.activation = "relu";
this.batch_size = 32;
this.prefetch = 64;
@ -40,7 +41,8 @@ class FilmPredictor {
this.model = make_model(
this.cats.count,
this.image_size,
this.use_crossentropy
this.use_crossentropy,
this.activation
);
this.model.summary();

View file

@ -2,8 +2,7 @@
import tf from '@tensorflow/tfjs-node';
export default function(cats_count, image_size, use_crossentropy = false) {
console.error(`>>> Creating new model`);
export default function(cats_count, image_size, use_crossentropy = false, activation = "relu") {
let model = tf.sequential();
model.add(tf.layers.conv2d({
@ -13,7 +12,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5,
filters: 3,
strides: 1,
activation: "tanh"
activation
}));
if(image_size > 32) {
model.add(tf.layers.conv2d({
@ -22,7 +21,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5,
filters: 3,
strides: 2,
activation: "tanh"
activation
}));
model.add(tf.layers.conv2d({
name: "conv2d_3",
@ -30,7 +29,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5,
filters: 3,
strides: 2,
activation: "tanh"
activation
}));
}

View file

@ -27,6 +27,7 @@ export default async function(settings) {
let model = new FilmPredictor(settings, preprocessor.cats);
model.image_size = settings.image_size;
model.use_crossentropy = settings.cross_entropy;
model.activation = settings.activation;
await model.init(); // We're training a new model here
await model.train(