Make activation function a setting on the CLI
This commit is contained in:
parent
926a82bebf
commit
10138cd310
4 changed files with 9 additions and 6 deletions
|
@ -16,6 +16,7 @@ export default async function () {
|
||||||
.argument("output", "Path to the output directory to save the trained AI to")
|
.argument("output", "Path to the output directory to save the trained AI to")
|
||||||
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
|
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
|
||||||
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
|
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
|
||||||
|
.argument("activation", "Set the activation function for the CNN layer(s) (default: relu)", "relu", "string");
|
||||||
cli.subcommand("predict", "Predicts the genres of the specified image")
|
cli.subcommand("predict", "Predicts the genres of the specified image")
|
||||||
.argument("input", "Path to the input image")
|
.argument("input", "Path to the input image")
|
||||||
.argument("ai-model", "Path to the saved AI model to load");
|
.argument("ai-model", "Path to the saved AI model to load");
|
||||||
|
|
|
@ -17,6 +17,7 @@ class FilmPredictor {
|
||||||
|
|
||||||
this.image_size = 256;
|
this.image_size = 256;
|
||||||
this.use_crossentropy = false;
|
this.use_crossentropy = false;
|
||||||
|
this.activation = "relu";
|
||||||
|
|
||||||
this.batch_size = 32;
|
this.batch_size = 32;
|
||||||
this.prefetch = 64;
|
this.prefetch = 64;
|
||||||
|
@ -40,7 +41,8 @@ class FilmPredictor {
|
||||||
this.model = make_model(
|
this.model = make_model(
|
||||||
this.cats.count,
|
this.cats.count,
|
||||||
this.image_size,
|
this.image_size,
|
||||||
this.use_crossentropy
|
this.use_crossentropy,
|
||||||
|
this.activation
|
||||||
);
|
);
|
||||||
|
|
||||||
this.model.summary();
|
this.model.summary();
|
||||||
|
|
|
@ -2,8 +2,7 @@
|
||||||
|
|
||||||
import tf from '@tensorflow/tfjs-node';
|
import tf from '@tensorflow/tfjs-node';
|
||||||
|
|
||||||
export default function(cats_count, image_size, use_crossentropy = false) {
|
export default function(cats_count, image_size, use_crossentropy = false, activation = "relu") {
|
||||||
console.error(`>>> Creating new model`);
|
|
||||||
let model = tf.sequential();
|
let model = tf.sequential();
|
||||||
|
|
||||||
model.add(tf.layers.conv2d({
|
model.add(tf.layers.conv2d({
|
||||||
|
@ -13,7 +12,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
||||||
kernelSize: 5,
|
kernelSize: 5,
|
||||||
filters: 3,
|
filters: 3,
|
||||||
strides: 1,
|
strides: 1,
|
||||||
activation: "tanh"
|
activation
|
||||||
}));
|
}));
|
||||||
if(image_size > 32) {
|
if(image_size > 32) {
|
||||||
model.add(tf.layers.conv2d({
|
model.add(tf.layers.conv2d({
|
||||||
|
@ -22,7 +21,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
||||||
kernelSize: 5,
|
kernelSize: 5,
|
||||||
filters: 3,
|
filters: 3,
|
||||||
strides: 2,
|
strides: 2,
|
||||||
activation: "tanh"
|
activation
|
||||||
}));
|
}));
|
||||||
model.add(tf.layers.conv2d({
|
model.add(tf.layers.conv2d({
|
||||||
name: "conv2d_3",
|
name: "conv2d_3",
|
||||||
|
@ -30,7 +29,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
||||||
kernelSize: 5,
|
kernelSize: 5,
|
||||||
filters: 3,
|
filters: 3,
|
||||||
strides: 2,
|
strides: 2,
|
||||||
activation: "tanh"
|
activation
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ export default async function(settings) {
|
||||||
let model = new FilmPredictor(settings, preprocessor.cats);
|
let model = new FilmPredictor(settings, preprocessor.cats);
|
||||||
model.image_size = settings.image_size;
|
model.image_size = settings.image_size;
|
||||||
model.use_crossentropy = settings.cross_entropy;
|
model.use_crossentropy = settings.cross_entropy;
|
||||||
|
model.activation = settings.activation;
|
||||||
await model.init(); // We're training a new model here
|
await model.init(); // We're training a new model here
|
||||||
|
|
||||||
await model.train(
|
await model.train(
|
||||||
|
|
Loading…
Reference in a new issue