One of the joys of deep learning is working with layers that you can stack up like Lego blocks: because the open-source community is so robust, you get the benefit of world-class research for free. Keras is a great abstraction for taking advantage of this work, allowing you to build powerful models quickly. Still, sometimes you need to define your own custom layers.
Keras has changed a lot over the last several years (as has the community at large). This tutorial works for tensorflow>=1.7.0 (up to at least version 2.4.0), which includes a fairly stable version of the Keras API.
How Keras custom layers work
Layer classes store network weights and define a forward pass. Let’s start with a simple custom layer that applies two linear transformations. We’ll explain each part throughout the post.
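import tensorflow as tf


class DoubleLinearLayer(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        # Weights are created here, once the input shape is known.
        # The name is passed by keyword so the call does not depend on
        # the positional order of add_weight's arguments.
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        # Forward pass: two linear transformations, no bias or activation.
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)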
You can use this Layer class in any Keras model and the rest of the functionality of the API will work correctly.
Each custom Layer class must define __init__() and call(), and usually build():
__init__() assigns layer-wide attributes (e.g. number of output units). If you know the input shape, you can also initialize the weights in the constructor.
call() defines the forward pass. As long as these operations are differentiable and the weights are set to be trainable, TensorFlow will handle backpropagation for you.
build() is not strictly required, but implementing it is a best practice. Defining this method allows you to instantiate weights lazily, which is important if you don’t know the size of the input when you initialize the custom layer.
Once implemented, you can use the layer like any other Layer class in Keras:
layer = DoubleLinearLayer()
x = tf.ones((3, 100))
layer(x)  # Returns a (3, 8) tensor
Notice: the size of the input (100 dimensions) is unknown when the Layer object is initialized. For this type of usage, you need to define build().
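To make the lazy instantiation concrete, here is a minimal sketch that inspects the layer before and after its first call. It assumes TensorFlow 2.x with eager execution; the variable names weights_before and weight_shapes are only for illustration.

```python
import tensorflow as tf


class DoubleLinearLayer(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        # The last input dimension only becomes known here, at first call.
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)


layer = DoubleLinearLayer()
weights_before = len(layer.weights)  # no weights have been created yet

output = layer(tf.ones((3, 100)))  # the first call triggers build()
weight_shapes = [tuple(w.shape) for w in layer.weights]

print(weights_before, weight_shapes, tuple(output.shape))
```

The first weight matrix takes its row count from the input, which is why it cannot be created in the constructor unless the input size is passed in explicitly.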
By default, weights are trainable, but you can override this behavior. For example, if you wanted to initialize the weights2 matrix to all ones and freeze it during training, you could set trainable=False in the add_weight() call:
self.weights2 = self.add_weight(
    name="weights2",
    shape=(self.n_units, self.n_units),
    initializer=tf.keras.initializers.Ones(),
    trainable=False,  # exclude this matrix from gradient updates
)
This is useful if you want to freeze pre-trained weights during training.
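One way to check that the freeze took effect is to look at the layer's trainable_weights and non_trainable_weights lists. The sketch below assumes TensorFlow 2.x with eager execution; PartlyFrozenLayer is a hypothetical name used only for this example.

```python
import tensorflow as tf


class PartlyFrozenLayer(tf.keras.layers.Layer):  # hypothetical name
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        self.weights1 = self.add_weight(
            name="weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        # Frozen: initialized to ones and excluded from gradient updates.
        self.weights2 = self.add_weight(
            name="weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.Ones(),
            trainable=False,
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)


layer = PartlyFrozenLayer()
layer(tf.ones((3, 100)))  # build the weights

n_trainable = len(layer.trainable_weights)    # weights1 only
n_frozen = len(layer.non_trainable_weights)   # weights2 only
frozen_values = layer.weights2.numpy()        # the all-ones matrix

print(n_trainable, n_frozen, frozen_values.shape)
```

Optimizers only update variables in trainable_weights, so weights2 should keep its initial values throughout training.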
Training vs inference mode
call() also accepts a training arg. This allows you to control the difference in execution between training and inference. Some layers like dropout and batch normalization behave differently in those two modes. Here’s a toy example:
def call(self, inputs, training=None):
    if training:
        return inputs + 1
    return inputs
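As a runnable version of the toy example, the sketch below wraps that call() in a complete layer and invokes it in both modes. It assumes TensorFlow 2.x with eager execution; AddOneWhenTraining is a hypothetical name.

```python
import tensorflow as tf


class AddOneWhenTraining(tf.keras.layers.Layer):  # hypothetical name
    def call(self, inputs, training=None):
        # Shift the inputs only in training mode.
        if training:
            return inputs + 1
        return inputs


layer = AddOneWhenTraining()
x = tf.zeros((2, 3))

train_out = layer(x, training=True)   # every element becomes 1.0
infer_out = layer(x, training=False)  # the input passes through unchanged

print(train_out.numpy())
print(infer_out.numpy())
```

When the layer is part of a model, Keras sets the flag for you: fit() runs layers with training=True, while predict() and evaluate() run them with training=False.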
Layers are composable
A useful property of Keras layers is that they’re composable. Taking advantage of this is an important way to keep your model code as DRY as possible. For example, you can create a double linear layer with the same architecture by using Dense layers instead of initializing the weights yourself:
class DoubleLinearLayer2(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        # Each Dense layer creates and tracks its own kernel.
        self.dense1 = tf.keras.layers.Dense(n_units, use_bias=False)
        self.dense2 = tf.keras.layers.Dense(n_units, use_bias=False)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)
Because Dense layers take advantage of the build method, you can call the layer in the same way as DoubleLinearLayer:
layer2 = DoubleLinearLayer2()
layer2(x)  # Returns a (3, 8) tensor
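A quick way to see that the composed layer matches the hand-written one is to compare kernel shapes after the first call. This is a minimal sketch assuming TensorFlow 2.x with eager execution; kernel_shapes and n_tracked are illustrative names.

```python
import tensorflow as tf


class DoubleLinearLayer2(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(n_units, use_bias=False)
        self.dense2 = tf.keras.layers.Dense(n_units, use_bias=False)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)


layer2 = DoubleLinearLayer2()
layer2(tf.ones((3, 100)))  # the first call builds both Dense sub-layers

kernel_shapes = [
    tuple(layer2.dense1.kernel.shape),
    tuple(layer2.dense2.kernel.shape),
]
# The outer layer tracks the weights of the layers assigned as attributes.
n_tracked = len(layer2.trainable_weights)

print(kernel_shapes, n_tracked)
```

The two kernels have the same shapes as weights1 and weights2 in the hand-written layer, and the outer layer exposes them through trainable_weights without any extra bookkeeping.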
Models are like layers
By the way, the Keras Layer class has the same basic API as the Model class. Model has more methods exposed (e.g. save), but in the same way that a Layer can be composed of other Layers, a Model can be composed of Models and Layers. This is useful when you borrow functionality from pre-trained models:
class CustomClassifier(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # With the default arguments, ImageNet weights are downloaded on first use.
        self.resnet = tf.keras.applications.ResNet50(include_top=False)
        self.flatten = tf.keras.layers.Flatten()
        self.head = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.resnet(inputs)
        x = self.flatten(x)
        return self.head(x)
In this case, ResNet50 is a Model class, but it can be used like any other layer inside the CustomClassifier Model. Notice also that the structure of the Model class is the same as the custom layers defined above. You can run a forward pass with this custom model in the same way:
x2 = tf.ones((12, 224, 224, 3))
model = CustomClassifier()
model(x2)  # Returns a (12, 10) tensor
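To illustrate the extra methods a Model exposes without downloading pre-trained weights, here is a small sketch that swaps ResNet50 for a single Dense layer and runs one epoch of training on dummy data. It assumes TensorFlow 2.x; TinyClassifier, the random features, and the all-zero labels are made up for the example.

```python
import tensorflow as tf


class TinyClassifier(tf.keras.Model):  # hypothetical stand-in for CustomClassifier
    def __init__(self, n_classes=10):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, use_bias=False)
        self.head = tf.keras.layers.Dense(n_classes, activation="softmax")

    def call(self, inputs):
        x = self.hidden(inputs)
        return self.head(x)


model = TinyClassifier()
features = tf.random.normal((12, 100))
labels = tf.zeros((12,), dtype=tf.int32)  # dummy labels, all class 0

# compile() and fit() are Model methods; a plain Layer does not have them.
model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy")
history = model.fit(features, labels, epochs=1, verbose=0)

probs = model(features)  # forward pass, same calling convention as a Layer
print(tuple(probs.shape), list(history.history.keys()))
```

The class body looks just like the custom layers above; subclassing Model instead of Layer is what adds the training loop on top.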