One of the joys of deep learning is working with layers that you can stack up like Lego blocks – you get the benefit of world-class research because the open-source community is so robust. Keras is a great abstraction for taking advantage of this work, allowing you to build powerful models quickly. Still, sometimes you need to define your own custom layers.
Keras has changed a lot over the last several years (as has the community at large). This tutorial works for tensorflow>=1.7.0 (up to at least version 2.4.0), which includes a fairly stable version of the Keras API.
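If you're not sure which version you're running, a quick check (tf.__version__ is a standard attribute across TensorFlow releases):

import tensorflow as tf

# Prints the installed TensorFlow version, e.g. "2.4.0"
print(tf.__version__)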
How Keras custom layers work
Layer classes store network weights and define a forward pass. Let’s start with a simple custom layer that applies two linear transformations. We’ll explain each part throughout the post.
import tensorflow as tf

class DoubleLinearLayer(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        self.n_units = n_units

    def build(self, input_shape):
        # Create the weights lazily, once the input shape is known.
        self.weights1 = self.add_weight(
            "weights1",
            shape=(int(input_shape[-1]), self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )
        self.weights2 = self.add_weight(
            "weights2",
            shape=(self.n_units, self.n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        # Forward pass: two linear transformations, no bias or activation.
        x = tf.matmul(inputs, self.weights1)
        return tf.matmul(x, self.weights2)
You can use this Layer class in any Keras model, and the rest of the API will work as expected.
Methods
Each custom Layer class must define __init__(), call(), and (usually) build():

- __init__() assigns layer-wide attributes (e.g. the number of output units). If you know the input shape, you can also initialize the weights in the constructor (see the sketch after this list).
- call() defines the forward pass. As long as these operations are differentiable and the weights are marked trainable, TensorFlow will handle backpropagation for you.
- build() is not strictly required, but implementing it is a best practice. Defining this method allows you to instantiate weights lazily, which is important if you don't know the size of the input when you initialize the custom layer.
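As a minimal sketch of the constructor-initialized case (a hypothetical variant, not from the original post): when the input dimension is known up front, you can create the weights eagerly and skip build() entirely.

class EagerLinearLayer(tf.keras.layers.Layer):
    # Hypothetical variant: input_dim is known up front, so no build() is needed.
    def __init__(self, input_dim, n_units=8):
        super().__init__()
        self.w = self.add_weight(
            "w",
            shape=(input_dim, n_units),
            initializer=tf.keras.initializers.RandomNormal(),
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w)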
Usage
Once implemented, you can use the layer like any other Layer class in Keras:
layer = DoubleLinearLayer()
x = tf.ones((3, 100))
layer(x)
# Returns a (3, 8) tensor
Notice: the size of the input layer (100 dimensions) is unknown when the Layer object is initialized. For this type of usage, you need to define build().
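You can watch the lazy initialization happen by inspecting the layer before and after the first call (a quick sketch; built and weights are standard properties on tf.keras.layers.Layer):

layer = DoubleLinearLayer()
print(layer.built)        # False: no weights exist yet
layer(tf.ones((3, 100)))  # the first call triggers build()
print(layer.built)        # True
print([w.shape for w in layer.weights])  # [(100, 8), (8, 8)]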
Trainable weights
By default, weights are trainable, but you can override this behavior. For example, if you wanted to initialize the weights2 variable to all ones and freeze it during training, you could set trainable=False in the add_weight() call:
self.weights2 = self.add_weight(
    "weights2",
    shape=(self.n_units, self.n_units),
    initializer=tf.keras.initializers.Ones(),
    trainable=False,
)
This is useful if you want to freeze pre-trained weights during training.
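Keras then tracks the frozen variable separately. A quick sketch of how you'd verify it, assuming you've applied the change above to DoubleLinearLayer (trainable_weights and non_trainable_weights are standard Layer properties):

layer = DoubleLinearLayer()
layer(tf.ones((3, 100)))  # trigger build()
print([w.name for w in layer.trainable_weights])      # weights1 only
print([w.name for w in layer.non_trainable_weights])  # weights2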
Training vs inference mode
call() also accepts a training argument, which lets you control the difference in execution between training and inference. Some layers, like dropout and batch normalization, behave differently in those two modes. Here's a toy example:
def call(self, inputs, training=None):
    if training:
        return inputs + 1
    return inputs
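When you call a layer directly, you pass the flag yourself; model.fit() and model.evaluate() set it for you. A minimal sketch, wrapping the toy call() above in a hypothetical AddOneInTraining layer:

class AddOneInTraining(tf.keras.layers.Layer):
    # Hypothetical layer using the toy call() above.
    def call(self, inputs, training=None):
        if training:
            return inputs + 1
        return inputs

layer = AddOneInTraining()
x = tf.zeros((2, 2))
print(layer(x, training=True))   # all ones
print(layer(x, training=False))  # all zeros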
Layers are composable
A useful property of Keras layers is that they're composable. Taking advantage of this is an important way to keep your model code as DRY as possible. For example, you can create a double linear layer with the same architecture by using Dense layers instead of initializing the weights yourself:
class DoubleLinearLayer2(tf.keras.layers.Layer):
    def __init__(self, n_units=8):
        super().__init__()
        # Each Dense layer (without a bias term) is one linear transformation.
        self.dense1 = tf.keras.layers.Dense(n_units, use_bias=False)
        self.dense2 = tf.keras.layers.Dense(n_units, use_bias=False)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)
Because these Dense layers implement the build method themselves, you can call the layer in the same way as DoubleLinearLayer above:
layer2 = DoubleLinearLayer2()
layer2(x)
# Returns a (3, 8) tensor
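Composition also means the bookkeeping comes for free: the parent layer automatically tracks its children's variables. A quick sketch:

print(len(layer2.weights))                # 2: the kernels of dense1 and dense2
print([w.shape for w in layer2.weights])  # [(100, 8), (8, 8)]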
Models are like layers
By the way, the Keras Layer class has the same basic API as the Model class. Model exposes more methods (e.g. fit, evaluate, save), but in the same way that a Layer can be composed of other Layers, a Model can be composed of Models and Layers. This is useful when you borrow functionality from pre-trained models:
class CustomClassifier(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Pre-trained ResNet50 backbone, without its classification head.
        self.resnet = tf.keras.applications.ResNet50(include_top=False)
        self.flatten = tf.keras.layers.Flatten()
        self.head = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.resnet(inputs)
        x = self.flatten(x)
        return self.head(x)
In this case, ResNet50 is a Model class, but it can be used like any other layer inside the CustomClassifier Model. Notice also that the structure of the Model class is the same as that of the custom layers defined above. You can run a forward pass with this custom model in the same way:
x2 = tf.ones((12, 224, 224, 3))
model = CustomClassifier()
model(x2)
# Returns a (12, 10) tensor
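Because CustomClassifier is a Model, it also gets the training API for free. A minimal sketch with made-up labels (the data here is purely illustrative):

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
)
labels = tf.random.uniform((12,), maxval=10, dtype=tf.int32)  # fake labels
model.fit(x2, labels, epochs=1)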
Fin
And that’s Keras custom layers in a nutshell. Here’s a Jupyter notebook with the code snippets in case you want to play around with them. You should also check out the Keras API documentation.