Make activation function a setting on the CLI
This commit is contained in:
parent
926a82bebf
commit
10138cd310
4 changed files with 9 additions and 6 deletions
|
@ -16,6 +16,7 @@ export default async function () {
|
|||
.argument("output", "Path to the output directory to save the trained AI to")
|
||||
.argument("image-size", "The width+height of input images (images are assumed to be square; default: 256)", 256, "integer")
|
||||
.argument("cross-entropy", "Use categorical cross-entropy loss instead of mean-squared error (best if each image has only a single label)", false, "boolean")
|
||||
.argument("activation", "Set the activation function for the CNN layer(s) (default: relu)", "relu", "string");
|
||||
cli.subcommand("predict", "Predicts the genres of the specified image")
|
||||
.argument("input", "Path to the input image")
|
||||
.argument("ai-model", "Path to the saved AI model to load");
|
||||
|
|
|
@ -17,6 +17,7 @@ class FilmPredictor {
|
|||
|
||||
this.image_size = 256;
|
||||
this.use_crossentropy = false;
|
||||
this.activation = "relu";
|
||||
|
||||
this.batch_size = 32;
|
||||
this.prefetch = 64;
|
||||
|
@ -40,7 +41,8 @@ class FilmPredictor {
|
|||
this.model = make_model(
|
||||
this.cats.count,
|
||||
this.image_size,
|
||||
this.use_crossentropy
|
||||
this.use_crossentropy,
|
||||
this.activation
|
||||
);
|
||||
|
||||
this.model.summary();
|
||||
|
|
|
@ -2,8 +2,7 @@
|
|||
|
||||
import tf from '@tensorflow/tfjs-node';
|
||||
|
||||
export default function(cats_count, image_size, use_crossentropy = false) {
|
||||
console.error(`>>> Creating new model`);
|
||||
export default function(cats_count, image_size, use_crossentropy = false, activation = "relu") {
|
||||
let model = tf.sequential();
|
||||
|
||||
model.add(tf.layers.conv2d({
|
||||
|
@ -13,7 +12,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
|||
kernelSize: 5,
|
||||
filters: 3,
|
||||
strides: 1,
|
||||
activation: "tanh"
|
||||
activation
|
||||
}));
|
||||
if(image_size > 32) {
|
||||
model.add(tf.layers.conv2d({
|
||||
|
@ -22,7 +21,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
|||
kernelSize: 5,
|
||||
filters: 3,
|
||||
strides: 2,
|
||||
activation: "tanh"
|
||||
activation
|
||||
}));
|
||||
model.add(tf.layers.conv2d({
|
||||
name: "conv2d_3",
|
||||
|
@ -30,7 +29,7 @@ export default function(cats_count, image_size, use_crossentropy = false) {
|
|||
kernelSize: 5,
|
||||
filters: 3,
|
||||
strides: 2,
|
||||
activation: "tanh"
|
||||
activation
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ export default async function(settings) {
|
|||
let model = new FilmPredictor(settings, preprocessor.cats);
|
||||
model.image_size = settings.image_size;
|
||||
model.use_crossentropy = settings.cross_entropy;
|
||||
model.activation = settings.activation;
|
||||
await model.init(); // We're training a new model here
|
||||
|
||||
await model.train(
|
||||
|
|
Loading…
Reference in a new issue