Make activation function a setting on the CLI

This commit is contained in:
Starbeamrainbowlabs 2020-12-15 19:32:03 +00:00
parent 926a82bebf
commit 10138cd310
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
4 changed files with 9 additions and 6 deletions

View file

@ -16,6 +16,7 @@ export default async function () {
.argument("output", "Path to the output directory to save the trained AI to") .argument("output", "Path to the output directory to save the trained AI to")
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer") .argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean") .argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
.argument("activation", "Set the activation function for the CNN layer(s) (default: relu)", "relu", "string");
cli.subcommand("predict", "Predicts the genres of the specified image") cli.subcommand("predict", "Predicts the genres of the specified image")
.argument("input", "Path to the input image") .argument("input", "Path to the input image")
.argument("ai-model", "Path to the saved AI model to load"); .argument("ai-model", "Path to the saved AI model to load");

View file

@ -17,6 +17,7 @@ class FilmPredictor {
this.image_size = 256; this.image_size = 256;
this.use_crossentropy = false; this.use_crossentropy = false;
this.activation = "relu";
this.batch_size = 32; this.batch_size = 32;
this.prefetch = 64; this.prefetch = 64;
@ -40,7 +41,8 @@ class FilmPredictor {
this.model = make_model( this.model = make_model(
this.cats.count, this.cats.count,
this.image_size, this.image_size,
this.use_crossentropy this.use_crossentropy,
this.activation
); );
this.model.summary(); this.model.summary();

View file

@ -2,8 +2,7 @@
import tf from '@tensorflow/tfjs-node'; import tf from '@tensorflow/tfjs-node';
export default function(cats_count, image_size, use_crossentropy = false) { export default function(cats_count, image_size, use_crossentropy = false, activation = "relu") {
console.error(`>>> Creating new model`);
let model = tf.sequential(); let model = tf.sequential();
model.add(tf.layers.conv2d({ model.add(tf.layers.conv2d({
@ -13,7 +12,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5, kernelSize: 5,
filters: 3, filters: 3,
strides: 1, strides: 1,
activation: "tanh" activation
})); }));
if(image_size > 32) { if(image_size > 32) {
model.add(tf.layers.conv2d({ model.add(tf.layers.conv2d({
@ -22,7 +21,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5, kernelSize: 5,
filters: 3, filters: 3,
strides: 2, strides: 2,
activation: "tanh" activation
})); }));
model.add(tf.layers.conv2d({ model.add(tf.layers.conv2d({
name: "conv2d_3", name: "conv2d_3",
@ -30,7 +29,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
kernelSize: 5, kernelSize: 5,
filters: 3, filters: 3,
strides: 2, strides: 2,
activation: "tanh" activation
})); }));
} }

View file

@ -27,6 +27,7 @@ export default async function(settings) {
let model = new FilmPredictor(settings, preprocessor.cats); let model = new FilmPredictor(settings, preprocessor.cats);
model.image_size = settings.image_size; model.image_size = settings.image_size;
model.use_crossentropy = settings.cross_entropy; model.use_crossentropy = settings.cross_entropy;
model.activation = settings.activation;
await model.init(); // We're training a new model here await model.init(); // We're training a new model here
await model.train( await model.train(