The Keras Custom Layer Explained

Ben Cook • Posted 2020-12-26 • Last updated 2021-03-24

One of the joys of deep learning is working with layers that you can stack up like Lego blocks – you get the benefit of world-class research because the open-source community is so robust. Keras is a great abstraction for taking advantage of this work, allowing you to build powerful models quickly. Still, sometimes you need to define your own custom layers.

Keras has changed a lot over the last several years (as has the community at large). This tutorial works for tensorflow>=1.7.0 (up to at least version 2.4.0) which includes a fairly stable version of the Keras API.

How Keras custom layers work

Layer classes store network weights and define a forward pass. Let’s start with a simple custom layer that applies two linear transformations. We’ll explain each part throughout the post.

import tensorflow as tf


class DoubleLinearLayer(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        # Called once, the first time the layer sees an input.
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        # Forward pass: two linear transformations, no bias or activation.
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)

You can use this Layer class in any Keras model and the rest of the functionality of the API will work correctly.

Methods

Each custom Layer class must define __init__(), call(), and (usually) build():

  • __init__() assigns layer-wide attributes (e.g. number of output units). If you know the input shape, you can also initialize the weights in the constructor.
  • call() defines the forward pass. As long as these operations are differentiable and the weights are set to be trainable, TensorFlow will handle backpropagation for you.
  • build() is not strictly required, but implementing it is a best practice. Defining this method allows you to instantiate weights lazily, which is important if you don’t know the size of the input when you initialize the custom layer.
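
As a rough sketch of the first two bullets, the hypothetical FixedLinearLayer below (the class name and its input_dim parameter are assumptions for illustration, not part of Keras) creates its weight in the constructor because the input size is known up front. It then uses tf.GradientTape to show that gradients flow through call() without a hand-written backward pass.

```python
import tensorflow as tf


class FixedLinearLayer(tf.keras.layers.Layer):
    """Hypothetical layer: the input size is known, so no build() is needed."""

    def __init__(self, input_dim, n_units=8):
        super().__init__()
        # Create the weight eagerly in the constructor.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_dim, n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


layer = FixedLinearLayer(input_dim=100)
x = tf.ones((3, 100))

# call() only uses differentiable ops, so the tape can compute gradients.
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
grads = tape.gradient(loss, layer.trainable_weights)

print(grads[0].shape)  # (100, 8), one gradient entry per weight entry
```

The trade-off is that this layer only accepts 100-dimensional inputs, which is the limitation build() removes.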

Usage

Once implemented, you can use the layer like any other Layer class in Keras:

layer = DoubleLinearLayer()
x = tf.ones((3, 100))
layer(x)

# Returns a (3, 8) tensor

Notice: the size of the input (100 dimensions) is unknown when the Layer object is initialized. For this type of usage, you need to define build().
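
To make the lazy behavior concrete, here is a small sketch that repeats the DoubleLinearLayer definition so it runs on its own. It checks that no weights exist before the first call and that two instances adapt to two different input sizes. The variable names (weights_before, shape_100, shape_20) are only for illustration.

```python
import tensorflow as tf


class DoubleLinearLayer(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        # Runs once, on the first call, when the input shape is known.
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)


layer = DoubleLinearLayer()
weights_before = len(layer.weights)  # no weights exist yet

layer(tf.ones((3, 100)))  # the first call triggers build()
shape_100 = tuple(layer.weights1.shape)

# A second instance adapts to a different input size with no code changes.
other = DoubleLinearLayer()
other(tf.ones((3, 20)))
shape_20 = tuple(other.weights1.shape)

print(weights_before, shape_100, shape_20)  # 0 (100, 8) (20, 8)
```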

Trainable weights

By default, weights are trainable, but you can override this behavior. For example, if you wanted to initialize the weights2 matrix to all ones and freeze it during training, you could set trainable=False in the add_weight() call:

self.weights2 = self.add_weight(
    name="weights2",
    shape=(self.n_units, self.n_units),
    initializer=tf.keras.initializers.Ones(),
    trainable=False,  # excluded from gradient updates
)

This is useful if you want to freeze pre-trained weights during training.
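
One way to confirm the freeze took effect is to look at the layer's trainable_weights and non_trainable_weights lists. The sketch below uses a hypothetical FrozenDoubleLinearLayer (a name made up here) that applies the trainable=False change to the layer from earlier in the post.

```python
import tensorflow as tf


class FrozenDoubleLinearLayer(tf.keras.layers.Layer):
    """Hypothetical variant of DoubleLinearLayer with a frozen second matrix."""

    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        # Fixed at all ones and never updated by the optimizer.
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.Ones(),
            trainable=False,
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)


layer = FrozenDoubleLinearLayer()
x = tf.ones((3, 100))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
# Only the trainable weight is passed on for gradient computation.
grads = tape.gradient(loss, layer.trainable_weights)

n_trainable = len(layer.trainable_weights)
n_frozen = len(layer.non_trainable_weights)
print(n_trainable, n_frozen, len(grads))  # 1 1 1
```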

Training vs inference mode

call() also accepts a training argument. This allows you to control the difference in execution between training and inference. Some layers like dropout and batch normalization behave differently in those two modes. Here’s a toy example:

def call(self, inputs, training=None):
    if training:
        return inputs + 1
    return inputs
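
Wrapped in a complete layer, the toy call() above can be exercised in both modes. The class name AddOneInTraining is made up for this sketch. The built-in Dropout layer is included for comparison, since it passes inputs through unchanged at inference time.

```python
import tensorflow as tf


class AddOneInTraining(tf.keras.layers.Layer):
    """Hypothetical layer that only modifies its input in training mode."""

    def call(self, inputs, training=None):
        if training:
            return inputs + 1
        return inputs


layer = AddOneInTraining()
x = tf.zeros((2, 3))
train_out = layer(x, training=True)   # every entry is 1.0
infer_out = layer(x, training=False)  # every entry is 0.0

# Built-in layers follow the same convention: Dropout is a no-op at inference.
dropout = tf.keras.layers.Dropout(0.5)
ones = tf.ones((2, 3))
dropout_infer = dropout(ones, training=False)  # identical to `ones`
```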

Layers are composable

A useful property of Keras layers is that they’re composable. Taking advantage of this is an important way to keep your model code as DRY as possible. For example, you can create a double linear layer with the same architecture by using Dense layers instead of initializing the weights yourself:

class DoubleLinearLayer2(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(n_units, use_bias=False)
        self.dense2 = tf.keras.layers.Dense(n_units, use_bias=False)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

Because these Dense layers implement build(), you can call the layer in the same way as DoubleLinearLayer above:

layer2 = DoubleLinearLayer2()
layer2(x)

# Returns a (3, 8) tensor
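
As a quick check that the composed version ends up with the same weight shapes as the hand-written one, this sketch repeats the DoubleLinearLayer2 definition and inspects the kernels of the two Dense sub-layers after the first call. The variable names are only for illustration.

```python
import tensorflow as tf


class DoubleLinearLayer2(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(n_units, use_bias=False)
        self.dense2 = tf.keras.layers.Dense(n_units, use_bias=False)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)


layer2 = DoubleLinearLayer2()
n_before = len(layer2.weights)  # 0: the Dense sub-layers are not built yet

out = layer2(tf.ones((3, 100)))  # the first call builds both Dense layers
kernel1 = tuple(layer2.dense1.kernel.shape)
kernel2 = tuple(layer2.dense2.kernel.shape)

print(n_before, kernel1, kernel2, tuple(out.shape))
# 0 (100, 8) (8, 8) (3, 8)
```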

Models are like layers

By the way, the Keras Layer class has the same basic API as the Model class. Model exposes more methods (e.g. fit, evaluate, save), but in the same way that a Layer can be composed of other Layers, a Model can be composed of Models and Layers. This is useful when you borrow functionality from pre-trained models:

class CustomClassifier(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # include_top=False drops ResNet50's own classification layer.
        # By default this loads pre-trained ImageNet weights.
        self.resnet = tf.keras.applications.ResNet50(include_top=False)
        self.flatten = tf.keras.layers.Flatten()
        self.head = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.resnet(inputs)
        x = self.flatten(x)
        return self.head(x)

In this case, ResNet50 is a Model class, but it can be used like any other layer inside the CustomClassifier Model. Notice also that the structure of the Model class is the same as the custom layers defined above. You can run a forward pass with this custom model in the same way:

x2 = tf.ones((12, 224, 224, 3))
model = CustomClassifier()
model(x2)

# Returns a (12, 10) tensor
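
The extra methods that Model exposes can be tried on something much smaller than ResNet50. The sketch below is an illustration only: TinyRegressor, the random data, and the optimizer and loss choices are all assumptions made for this example. It composes two Dense layers into a Model and then calls compile(), fit() and evaluate() on it.

```python
import numpy as np
import tensorflow as tf


class TinyRegressor(tf.keras.Model):
    """Hypothetical model built from two Dense layers."""

    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.hidden(inputs))


# Made-up regression data: the target is the sum of the four input features.
x = np.random.rand(32, 4).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = TinyRegressor()
model.compile(optimizer="sgd", loss="mse")

# fit() and evaluate() exist on Model but not on a plain Layer.
history = model.fit(x, y, epochs=2, batch_size=8, verbose=0)
loss = model.evaluate(x, y, verbose=0)

print(len(history.history["loss"]))  # 2, one loss value per epoch
```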

Fin

And that’s Keras custom layers in a nutshell. Here’s a Jupyter notebook with the code snippets in case you want to play around with them. You should also check out the Keras API documentation.