Source Code Layout for Machine Learning Pipelines

Ben Cook • Posted 2019-03-12 • Last updated 2021-10-21

When I look at a new open source deep learning project, I start with several questions.

  • What’s the structure of the model?
  • How’s the model trained?
  • How’s the training data formatted? How’s it preprocessed?
  • How can I do inference with a trained model?

But for a machine learning pipeline to be useful to me in a real-world scenario, all of the above are table stakes. There’s no way to make progress on model architecture or hyperparameter optimization until these questions are well understood.

And although a ton of machine learning progress is being made in a transparent way, many research-focused repositories obfuscate the answers to these basic questions.

The goal of this post is to use a toy problem (MNIST) to propose a specific source code layout for machine learning projects. Any convention needs to be flexible enough to handle complex network architectures and training procedures, while also being consistent enough to speed up real-world ML development.

This pipeline trains an MNIST classifier from scratch and deploys versioned weight files to S3. The idea is that these weight files could be used for transfer learning in future pipelines (DAGs all the way down). So in this case, uploading to S3 will be considered deploying the model. You can find the source code here.

Here’s what I propose:

```
<project name>/
    <whatever else you need>
<whatever else you need>
```

The first thing to notice is that this is a Python package. This saves us from solving code distribution in an ad hoc way, and it gives us a command-line interface for free. It also ensures that all users install the right dependencies without anyone writing a page of docs on how to get the model running locally. If you build private ML pipelines, just set up a private PyPI repository and get in the habit of automatically publishing new packages when you push to GitHub.

Inside the Python package, we have several modules. Although this list of modules is not exhaustive, the names of these modules are intentional. I’ll describe them below.


The model module is the star of the show. It should be the first thing people look at when they start reading the code for your ML pipeline, and ideally they should get a high-level grasp of what’s going on very quickly just by reading this one module. Of course, most models are more complex than an MNIST digit classifier, so if this module gets too long to understand quickly, it can be replaced by a subpackage:

```
<project name>/
        <whatever else you need>
```

This should work well as long as it’s clear to visitors which file they should look at to get a high-level understanding of the model. We can accomplish this by exposing the right function or class in model/ and by using sensible module names.

For the digit classifier, I used Keras’ functional API to define a model (tf.keras.Model) and split up the feature extractor and classifier. Separating the feature extractor allows me to export weights for a convolutional stack that doesn’t require a specific image shape. The classifier does require a specific image shape because of the dense layers at the end.
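To make the shape argument concrete, here is a small plain-Python sketch (no TensorFlow needed) that tracks tensor shapes through a hypothetical conv stack: two 3×3 “valid” convolutions with 32 and 64 filters followed by 2×2 max pooling. These layer sizes are assumptions for illustration, not necessarily the ones used in the project.

```python
# Hypothetical conv stack (an assumption, not necessarily the project's layers):
# (layer type, kernel or pool size, output channels)
CONV_STACK = [("conv", 3, 32), ("conv", 3, 64), ("pool", 2, None)]


def feature_map_shape(height, width, channels=1, stack=CONV_STACK):
    """Return (height, width, channels) after running the stack on one image."""
    for kind, size, filters in stack:
        if kind == "conv":
            # A 'valid' convolution shrinks each spatial dimension by size - 1.
            height, width, channels = height - size + 1, width - size + 1, filters
        else:
            # Non-overlapping pooling divides each spatial dimension, rounding down.
            height, width = height // size, width // size
    return height, width, channels


def conv_kernel_shapes(channels=1, stack=CONV_STACK):
    """Conv weight shapes (kernel, kernel, in_channels, out_channels)."""
    shapes = []
    for kind, size, filters in stack:
        if kind == "conv":
            shapes.append((size, size, channels, filters))
            channels = filters
    return shapes


def dense_input_size(height, width):
    """Number of inputs a Dense layer receives after Flatten()."""
    h, w, c = feature_map_shape(height, width)
    return h * w * c


print(conv_kernel_shapes())      # [(3, 3, 1, 32), (3, 3, 32, 64)] for any image size
print(dense_input_size(28, 28))  # 9216
print(dense_input_size(32, 32))  # 12544
```

The convolution kernels never mention the image size, which is why the feature extractor’s weights can be reused on other inputs. A Dense layer placed after Flatten() would need a kernel with 9216 rows for 28×28 images but 12544 rows for 32×32 images, so its weights only fit one input shape.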

The feature and classifier functions take an instance of my config class (more on this below) and define the model layers. They also optionally take a pretrained boolean argument which says whether to load publicly available pretrained weights for the model. This makes inference easy:

```python
import tensorflow as tf
from mnist import mnist_classifier

(images, _), _ = tf.keras.datasets.mnist.load_data()
x = images[0]

model = mnist_classifier(pretrained=True)
```

[Figure: MNIST 5 example]

```python
>>> model.predict(x[None, ..., None] / 255).argmax()
```


Take a look at the model module for the MNIST classifier. Do you understand it in 30 seconds or less?


The training module tells us how the model is trained. If we needed to define a custom training loop, that would go here. But this can also just be the code required to delegate training to some other tool that has solved the problem better than we can. The cloud providers are actively innovating in this space, so don’t be afraid to let them handle training for you if it fits your use case. For this toy digit classifier, I just use Keras’ built-in training loop, which means most of the train_model() function is about setting up training and handling the result:

  1. Instantiate input and label tensors for the train and test sets.
  2. Call Model.compile() with the loss function and optimizer.
  3. Hand the datasets to Keras and let it train the model.
  4. Evaluate the trained model on the test set to get accuracy.
  5. Save the feature extractor and classifier weights.
  6. Print accuracy.

One important note about training is that it’s designed to be run through the CLI. In fact, train_model() is not exposed in the package’s public API.

The main reason training should primarily be considered a script and not a function to be called in downstream code is that manually training a model (say in a Jupyter notebook) prevents a result from being totally reproducible. If model weights are always a direct result of something in the version controlled repository, then previous results can be replicated by checking out an old version of code. To be clear, I’m not saying Jupyter notebooks don’t have a place in the ML development lifecycle, but I am saying that no production models should be trained in an ad hoc way.

Dataset and Preprocessing

To find out how the training data is formatted and how it’s preprocessed, we should look at the dataset module. For the MNIST pipeline, this module has two responsibilities:

  1. Convert raw data into train and test TFRecord datasets
  2. Create a generator of preprocessed data for training and evaluation

The raw training and test data come pre-packaged with Keras, using tf.keras.datasets.mnist.load_data(). This function returns two (x, y) tuples (one for train and one for test) where x is a sequence of 28×28 grayscale images and y is a 1-dimensional array of class indices.

The save_datasets() function uses a tf.python_io.TFRecordWriter object to write serialized (image, label) pairs to a TFRecord file. Like train_model(), this function is only called through the CLI.

The load_dataset() function uses the TensorFlow Dataset API to read records from the TFRecord files, shuffle them, divide them into batches and then preprocess them. This happens in the background, using a configurable number of threads.
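The order of operations in load_dataset() (shuffle, batch, preprocess, all off the main thread) can be illustrated without TensorFlow. The following is a plain-Python analogue, not the tf.data API: a bounded shuffle buffer similar in spirit to Dataset.shuffle(), fixed-size batching, and a single background thread that prefetches preprocessed batches into a queue. All function names here are hypothetical.

```python
import queue
import random
import threading


def shuffle_buffer(records, buffer_size, rng):
    """Yield records in random order using a bounded buffer (buffer_size > 0)."""
    buffer = []
    for record in records:
        if len(buffer) < buffer_size:
            buffer.append(record)
            continue
        # Emit a random buffered record and put the new one in its slot.
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = record
    # Drain whatever is left, also in random order.
    rng.shuffle(buffer)
    yield from buffer


def batches(records, batch_size):
    """Group records into lists of batch_size; the last batch may be smaller."""
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def prefetch(iterable, max_prefetch):
    """Consume `iterable` in a background thread, buffering up to max_prefetch items."""
    q = queue.Queue(maxsize=max_prefetch)

    def worker():
        try:
            for item in iterable:
                q.put((False, item))
        except Exception as exc:  # forward errors to the consumer
            q.put((True, exc))
            return
        q.put((True, None))  # sentinel: no more items

    threading.Thread(target=worker, daemon=True).start()
    while True:
        finished, item = q.get()
        if finished:
            if item is not None:
                raise item
            return
        yield item


def load_batches(records, batch_size, preprocess, buffer_size=1000, max_prefetch=2, seed=None):
    """Shuffle, batch and preprocess (image, label) records off the main thread."""
    rng = random.Random(seed)

    def pipeline():
        for batch in batches(shuffle_buffer(records, buffer_size, rng), batch_size):
            images, labels = zip(*batch)
            yield preprocess(list(images)), list(labels)

    return prefetch(pipeline(), max_prefetch)
```

As long as the prefetch queue stays full, the training loop never waits on shuffling or preprocessing. In the real pipeline, tf.data manages this work itself.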

Preprocessing is defined in dedicated functions to make it easy to see which preprocessing operations need to occur, but the actual work is done inside the dataset generator so that it doesn’t slow down the main thread. The preprocess_images() and preprocess_labels() functions are exposed in the public API to make inference as easy as possible.

TODO: Support NumPy arrays in preprocessing. One thing I didn’t do, which would be useful, is make the preprocessing functions capable of handling either tf.Tensor or np.ndarray objects. Since Keras can run inference on NumPy arrays directly, this would make inference one step easier.
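For the NumPy half of that TODO, a minimal sketch might look like the following. It assumes image preprocessing means scaling pixels to [0, 1] and adding a channel axis, matching the x[None, ..., None] / 255 step in the inference example. The function name preprocess_images_np is hypothetical and is not the project’s preprocess_images().

```python
import numpy as np


def preprocess_images_np(images):
    """Scale uint8 pixels to [0, 1] floats and add a trailing channel axis.

    Hypothetical helper. Accepts a single 28x28 image or a batch of shape (n, 28, 28).
    """
    images = np.asarray(images, dtype=np.float32) / 255.0
    if images.ndim == 2:
        images = images[None, ...]  # single image -> batch of one
    return images[..., None]  # (n, 28, 28) -> (n, 28, 28, 1)
```

With a helper like this, inference on a raw NumPy image could become model.predict(preprocess_images_np(x)).argmax().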

Files and Versioning

For better or worse, a big part of the problem with ML pipelines is file storage. How do you store raw data, intermediate datasets, weight files, config files, etc.? Furthermore, as code changes, those files will sometimes need to change as well.

My view is that different files should be handled differently:

  • As part of the settings required to achieve model results, the config file needs to be version controlled alongside code.
  • As described above, the raw data in the MNIST pipeline is handled by Keras. In other cases, the raw data might come from a labeling service.
  • Because I use the TensorFlow Dataset API, I need to store train.tfrecord and test.tfrecord files locally to feed training and evaluation. These are stored in the artifact directory which defaults to $HOME/.mlpipes/mnist but can be overridden. Because these are quick to generate, I don’t store them remotely.

As the main output of the pipeline, the weight files are special and worth describing in more detail. After training, these files are stored locally in the artifact directory. From there, they are deployed by uploading them to S3 as part of the build process. They are also versioned the same way the Python package is versioned: every time the version gets bumped, newly deployed weight files carry the new version.
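As a rough sketch of how weight files might carry the package version, the helpers below build a local path under the artifact directory and a matching S3 key. The file-name pattern, the .h5 extension, the S3 key layout, the bucket argument and the hard-coded version string are all hypothetical. The upload uses the boto3 S3 client’s upload_file(), imported lazily so the path helpers work without boto3 installed.

```python
from pathlib import Path

# Hypothetical: a real package would read this from its own version metadata.
__version__ = "0.1.0"

# Default artifact directory described in the post ($HOME/.mlpipes/mnist).
DEFAULT_ARTIFACT_DIR = Path.home() / ".mlpipes" / "mnist"


def weights_filename(model_name, version=__version__):
    """File name that embeds the package version, e.g. 'mnist_features-0.1.0.h5'."""
    return f"{model_name}-{version}.h5"


def local_weights_path(model_name, version=__version__, artifact_dir=DEFAULT_ARTIFACT_DIR):
    """Where training would save the weights before deployment."""
    return Path(artifact_dir) / weights_filename(model_name, version)


def s3_key(model_name, version=__version__, prefix="mnist"):
    """S3 object key, grouped by version so each release has its own prefix."""
    return f"{prefix}/{version}/{weights_filename(model_name, version)}"


def deploy_weights(model_name, bucket, version=__version__, artifact_dir=DEFAULT_ARTIFACT_DIR):
    """Upload one weights file to S3 (the 'deploy' step) and return its key."""
    import boto3  # lazy import: only needed when actually deploying

    key = s3_key(model_name, version)
    path = local_weights_path(model_name, version, artifact_dir)
    boto3.client("s3").upload_file(str(path), bucket, key)
    return key
```

Tying the key to the package version means a given release of the code always points at the weights it produced, and bumping the version never overwrites older weights.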

Exposing Functions

After pip installing the package, there are two ways to interact with the code:

  1. Import functions in downstream packages or notebooks.
  2. Execute commands with the CLI.

Importing functions in future code is the most common use case for Python packages. Importing functions from an ML pipeline package would be useful for inference (which should work out of the box) and for inheriting some pipeline functionality in future pipelines. For example, in a future pipeline I might want to build an English letter classifier that takes features from the digit classifier and replaces the last layers (which depend on the number of classes). I could import the MnistFeatures model, load the pretrained weights and avoid repeating work in that pipeline. As a matter of fact, I plan to write this pipeline next.

But other functionality, such as training and dataset generation, is better exposed through a CLI. Fortunately, Google’s Python Fire package makes this easy. To turn the pipeline package into a CLI, I added a __main__ module that calls fire.Fire() with a dictionary of the functions I want to expose:

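```python
import fire

from .dataset import save_datasets
from .train import train_model


def main():
    """Expose CLI functions."""
    # Map each CLI subcommand name to the function it runs.
    fire.Fire({
        'save-datasets': save_datasets,
        'train-model': train_model,
    })


# Also allow `python -m mnist ...` in addition to the console script.
if __name__ == '__main__':
    main()
```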

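This main() function is referenced by the entry_points argument of the setup() call in the package’s setup script:

```python
from setuptools import setup

setup(
    # ... other arguments (name, version, packages, install_requires, ...) ...
    entry_points={
        'console_scripts': [
            'mnist = mnist.__main__:main',
        ],
    },
)
```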

This little trick allows a user to call the exposed functions with a command-line tool called mnist. For example, mnist save-datasets will save the train and test datasets to the artifact directory. mnist save-datasets /path/to/config.yml will save datasets using any important overrides defined in a config.yml file and mnist save-datasets -- --help prints the docstring and usage instructions to the terminal. Neat!


Every neural network has configurable values that have to be set. Using a dataclass to define these values in code makes it easy to access these values in any other module.

One advantage of this approach is the ability to define derived values programmatically. For example, when we pass a TensorFlow dataset to Keras for training, we need to tell Keras how many steps are in an epoch. The steps per epoch should be the total number of training examples divided by the batch size. Since we already defined the batch size, we can add a steps_per_epoch property to our dataclass.

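```python
# Defined on the config dataclass.
@property
def steps_per_epoch(self) -> int:
    # Number of full batches in one pass over the training set.
    return self.n_train_samples // self.batch_size
```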

This can be accessed on the instantiated config object with config.steps_per_epoch.

Although defining config in code allows us to be explicit about what is configurable throughout the package, config files are useful when using the CLI. The compromise struck here is to use a YAML file to pass in overrides to the default config values defined in code. An MnistConfig.from_yaml() classmethod takes the path to a YAML file and instantiates a new config object in which the values from the file override the defaults. This keeps the config file terse: it only defines the values that deviate from the defaults. The config file I’m currently using defines only two values:

```yaml
# Config overrides
# See mnist/ for defaults
batch_size: 64
n_epochs: 2
```
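Putting the config pieces together, a minimal sketch of such a class might look like this. Only MnistConfig, batch_size, n_epochs, n_train_samples, steps_per_epoch and from_yaml() come from the post. The default batch size and epoch count, the from_dict() helper and the unknown-key check are assumptions; n_train_samples defaults to MNIST’s 60,000 training images. from_yaml() relies on PyYAML’s yaml.safe_load().

```python
from dataclasses import dataclass, fields


@dataclass
class MnistConfig:
    # Hypothetical defaults; only n_train_samples reflects MNIST itself.
    batch_size: int = 128
    n_epochs: int = 10
    n_train_samples: int = 60_000

    @property
    def steps_per_epoch(self) -> int:
        # Derived value: full batches per pass over the training set.
        return self.n_train_samples // self.batch_size

    @classmethod
    def from_dict(cls, overrides):
        """Start from the defaults and replace only the keys that are given."""
        known = {f.name for f in fields(cls)}
        unknown = set(overrides) - known
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        return cls(**overrides)

    @classmethod
    def from_yaml(cls, path):
        """Read overrides from a YAML file and apply them to the defaults."""
        import yaml  # PyYAML, imported lazily so the class works without it

        with open(path) as f:
            return cls.from_dict(yaml.safe_load(f) or {})
```

With the two-value override file above, the resulting config would have steps_per_epoch equal to 60000 // 64 = 937, while every value the file leaves out keeps its default from code.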


In summary, I want to make a few points about source code layout for ML pipelines:

  • Your ML pipeline should be a Python package.
  • Inference should be easy for newcomers.
  • Expose pipeline commands through a CLI.
  • Use simple module names so people know where to look.

Because code speaks louder than words, you should check out the project on GitHub here. A couple of major areas I didn’t cover are testing and continuous integration. You can get a feel for some of my thoughts on these subjects by looking at the MNIST pipeline. I plan to go into more depth in future posts.