REST API Serving

Here is a short example of how to define a simple model in Neuraxle and then wrap that model to serve predictions over a REST API. This is done in a few steps:

  1. Create, train, and evaluate your model.

  2. Create an object to transform your JSON inputs to data inputs, and an object to transform your predictions to JSON responses.

  3. Finally serve your pipeline as a REST API.

Note that it’d be even better to have your pipeline serialized (with our proper serialization and saving techniques) so that your app doesn’t need to retrain every time it starts. You could also add caching with our caching wrappers to optimize your whole pipeline or only specific parts of it.

Import Packages

Let’s begin.

import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_boston
from sklearn.decomposition import PCA, FastICA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

from neuraxle.api.flask import FlaskRestApiWrapper, JSONDataBodyDecoder, JSONDataResponseEncoder
from neuraxle.pipeline import Pipeline
from neuraxle.steps.sklearn import RidgeModelStacking
from neuraxle.union import AddFeatures

Load your Dataset

Here, we’ll simply use the Boston Housing Dataset, and do a train test split.

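# Note: load_boston was removed in scikit-learn 1.2, so this example needs an older version.
boston = load_boston()
X, y = shuffle(boston.data, boston.target, random_state=13)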
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, shuffle=False)

Create your Pipeline

This is a simple pipeline with model stacking and clustering preprocessing where:

  1. Decomposition/clustering of the dataset is made and concatenated to the data itself as additional features.

  2. A regression model and a clustering model are fitted on the resulting data.

  3. Finally, a Ridge regression (similar to a linear regression) is stacked on top of the regression and clustering to re-do a final regression.

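# The steps follow the description above and the imports; the hyperparameters
# (n_components=2) are an assumption.
pipeline = Pipeline([
    AddFeatures([
        PCA(n_components=2),
        FastICA(n_components=2),
    ]),
    RidgeModelStacking([
        GradientBoostingRegressor(),
        KMeans(),
    ]),
])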
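To make step 1 of the description concrete, here is a NumPy-only sketch of what "concatenated to the data itself as additional features" means. It is not Neuraxle's AddFeatures implementation: the decomposition is swapped for a plain SVD projection, and the toy data shape (6 samples, 4 features) is made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))  # toy data: 6 samples, 4 original features

# Stand-in for a decomposition step: project the centered data
# onto its top 2 principal directions.
X_centered = X - X.mean(axis=0)
_, _, vt = np.linalg.svd(X_centered, full_matrices=False)
derived = X_centered @ vt[:2].T  # shape (6, 2)

# Concatenate the derived features to the data itself, column-wise.
X_augmented = np.concatenate([X, derived], axis=1)
print(X_augmented.shape)
```

The original columns are left untouched; the later steps simply see a wider feature matrix.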

Let’s Train and Test

Yay. As usual.

print("Fitting on train:")
pipeline = pipeline.fit(X_train, y_train)

print("Transforming train and test:")
y_train_predicted = pipeline.transform(X_train)
y_test_predicted = pipeline.transform(X_test)

print("Evaluating transformed train:")
score = r2_score(y_train_predicted, y_train)
print('R2 regression score:', score)

print("Evaluating transformed test:")
score = r2_score(y_test_predicted, y_test)
print('R2 regression score:', score)
Fitting on train:

Transforming train and test:

Evaluating transformed train:
R2 regression score: 0.9800768559459768

Evaluating transformed test:
R2 regression score: 0.9214823389806873
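
One detail worth checking in the calls above: scikit-learn documents r2_score as r2_score(y_true, y_pred), and the score is not symmetric in its arguments. The code above passes the predictions first, so the printed values correspond to that order. The sketch below uses a small hand-written R² (an illustration, not scikit-learn's implementation) and made-up numbers to show that the two orders can differ.

```python
def r2(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, where SS_tot is taken around mean(y_true)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Made-up values for illustration only.
targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.0, 2.0, 3.0, 5.0]

score_true_first = r2(targets, predictions)   # documented argument order
score_pred_first = r2(predictions, targets)   # order used in the calls above

print(score_true_first, score_pred_first)
```

The residuals are the same in both calls; only the variance used as the denominator changes, which is why the two scores disagree.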

Deploy the Pipeline

Write a step to decode the accepted JSON as data inputs

We create a CustomJSONDecoderFor2DArray class that maps the JSON request body to the data inputs format that the pipeline expects.

class CustomJSONDecoderFor2DArray(JSONDataBodyDecoder):
    """This is a custom JSON decoder class that precedes the pipeline's transformation."""

    def decode(self, data_inputs):
        """
        Transform a JSON list object into an np.array object.

        :param data_inputs: json object
        :return: np array for data inputs
        """
        return np.array(data_inputs)
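
To make the decoding step concrete, here is a minimal sketch of what happens to a request body, assuming the JSON has already been parsed into nested Python lists by the time decode is called. The body string and its 2 by 3 shape are made up for illustration; rows of the Boston dataset would carry 13 values each.

```python
import json

import numpy as np

# Hypothetical request body: 2 samples with 3 features each.
body = "[[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]"

parsed = json.loads(body)        # nested Python lists
data_inputs = np.array(parsed)   # what decode() returns for this input

print(data_inputs.shape, data_inputs.dtype)
```

Each inner list becomes one row, so the client must send a list of samples even when asking for a single prediction.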

Write a step to encode the returned JSON response

We create a CustomJSONEncoderOfOutputs class that converts the pipeline’s predictions to a dict, which is then encoded as JSON in the Flask HTTP response.

class CustomJSONEncoderOfOutputs(JSONDataResponseEncoder):
    """This is a custom JSON response encoder class for converting the pipeline's transformation outputs."""

    def encode(self, data_inputs) -> dict:
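        """
        Convert predictions to a dict for creating a JSON Response object.

        :param data_inputs: predictions returned by the pipeline
        :return: dict holding the predictions under the 'predictions' key
        """
        return {
            'predictions': list(data_inputs)
        }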
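
A caveat on list(data_inputs): it keeps NumPy scalar types as the list items. That serializes fine for float64 predictions, because np.float64 subclasses Python's float, but the standard json module rejects np.float32. If your pipeline might output other dtypes, .tolist() is a safer conversion. A minimal sketch with made-up values:

```python
import json

import numpy as np

preds64 = np.array([19.87, 20.39])     # made-up float64 predictions
preds32 = preds64.astype(np.float32)   # same values as float32

# float64 items serialize because np.float64 is a subclass of float.
json_from_float64 = json.dumps({'predictions': list(preds64)})

# float32 items are rejected by the standard json encoder.
try:
    json.dumps({'predictions': list(preds32)})
    float32_list_ok = True
except TypeError:
    float32_list_ok = False

# .tolist() converts every item to a plain Python float first.
json_from_tolist = json.dumps({'predictions': preds32.tolist()})

print(float32_list_ok, json_from_tolist)
```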

Finally Serve Predictions

Wrapping the pipeline in a FlaskRestApiWrapper, surrounded by the decoder and the encoder, creates a Flask app that calls the wrapped pipeline’s transform method on each GET HTTP request:

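# The arguments are reconstructed from the description above;
# check them against your Neuraxle version.
app = FlaskRestApiWrapper(
    json_decoder=CustomJSONDecoderFor2DArray(),
    wrapped=pipeline,
    json_encoder=CustomJSONEncoderOfOutputs(),
).get_app()

app.run(debug=False, port=5000)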
 * Serving Flask app "neuraxle.api.flask" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on (Press CTRL+C to quit)

API Call Example

Here is an example of how you can call your web server.

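import json
import urllib.request  # "import urllib" alone does not load the request submodule

req = urllib.request.Request(
    'http://127.0.0.1:5000/',  # assumed: Flask's default host, with the port chosen above
    method="GET",
    headers={'content-type': 'application/json'},
    # Assumed payload: the 127 predictions below match the size of X_test.
    data=json.dumps(X_test.tolist()).encode('utf8'),
)
response = urllib.request.urlopen(req)
test_predictions = json.loads(response.read())
print(test_predictions)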

{'predictions': [19.873746987307726, 20.390153832383643, 22.97201871551186, 23.874400960304612, 10.968518175838302, 44.652760694913646, 13.816282777815985, 23.344520152369984, 28.486888607231197, 18.56506013715421, 23.7594614258383, 20.191176891772432, 36.2769592032176, 36.25702640450678, 19.371339162786125, 32.16858040891128, 14.313426671008866, 22.790168769635873, 32.06766358029434, 8.144263755321923, 28.904093430857845, 34.06953848967775, 35.17431399552918, 7.194525286872402, 20.274565060209937, 30.897752591762448, 46.45969965918759, 10.315137442897658, 21.346127624675642, 26.317699355096448, 10.564530840678437, 21.05444722670945, 20.384796887083873, 21.19450756202019, 16.58468318354116, 15.091556639056689, 39.45149431260311, 15.265065224691435, 18.56314453103865, 19.48993065069676, 32.31605195488036, 11.374637246243823, 20.719567350986896, 20.50324741600919, 15.942993614223013, 25.226966008165522, 19.939682605048876, 21.32976204884988, 20.295875714929625, 18.863382900124986, 24.34861831253574, 24.425343947024842, 27.484109525964193, 41.853672953098595, 43.74518956371233, 15.253375735335647, 26.437351682047925, 13.830575803674767, 30.85200512497349, 19.455373582617987, 25.01318280522918, 24.88029511190593, 18.9659126247741, 13.361687538477161, 42.615612112860425, 49.548038681675095, 23.410397736906372, 25.15542859397714, 12.526689740725828, 46.5692598296103, 18.43279335121059, 15.400204203460447, 20.23152184623182, 9.535471650105965, 19.414499498352942, 16.54798266827169, 15.961840567488832, 27.98490403763845, 24.645754095886534, 31.108649309779597, 10.31246800290689, 38.14587812882758, 24.009109459855267, 10.40370347916893, 10.365975924720953, 44.45504493621803, 31.93287948353104, 17.043630939120135, 25.156870933612076, 17.87773701480244, 26.54615511293424, 21.032127364114977, 11.583823947829023, 48.86116329277093, 16.85568791820813, 22.27209619233265, 23.168516426657312, 17.607519119791718, 17.10872265480015, 20.974380799318563, 15.433509990032093, 
11.256443323063525, 24.897844062550668, 35.31360602502965, 28.550586526314138, 16.004667499298463, 22.66517018656603, 26.14831131782035, 18.002262492034443, 30.821792973358434, 7.551590274166486, 14.167993486002153, 12.328191579568385, 17.865792911701895, 18.8565897533606, 20.155976666382397, 20.087472233334893, 14.194419191579033, 26.11334257787488, 17.757888275774732, 18.68891017252448, 20.46699485794454, 21.804942367804273, 21.928863724108417, 24.52420197283668, 22.2853894914772, 21.45580073077518]}

Next Steps

Pipeline Serialization

You may want to learn more about how our pipelines are serialized to avoid retraining every time. For instance, if you use a TensorFlow, PyTorch, or Keras step within your pipeline, you’ll need to use different serializers for those steps. Overall, a whole pipeline can be serialized into a tree of subfolders: a default serializer is provided, and each step can define its own serializers when needed.
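
The general idea of "fit once, load at startup" can be sketched with the standard library alone. This is not Neuraxle's serialization API: the toy model, the pickle format, and the file location are all stand-ins chosen for illustration.

```python
import os
import pickle
import tempfile

def fit_mean_model(y):
    """Toy 'training': the fitted state is just the mean of the targets."""
    return {"mean": sum(y) / len(y)}

def predict(state, n_samples):
    """Toy 'prediction': repeat the stored mean for every sample."""
    return [state["mean"]] * n_samples

# Hypothetical location for the saved model.
model_path = os.path.join(tempfile.mkdtemp(), "model.pkl")

# Training script: fit once and persist the fitted state.
with open(model_path, "wb") as f:
    pickle.dump(fit_mean_model([1.0, 2.0, 3.0]), f)

# Serving script: load at startup instead of refitting.
with open(model_path, "rb") as f:
    served_state = pickle.load(f)

restored_predictions = predict(served_state, 2)
print(restored_predictions)
```

Steps backed by deep learning frameworks usually cannot be saved this simply, which is why per-step serializers matter.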

Data Transformation Caching

Pipelines can be set up to always save and cache the data they process, so that they skip the computations and return the result directly when they see the same data again. This is flexible: you can enable caching for only some specific steps, and by creating your own caching class you can use any database and any cache item limit you need.
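
As a rough illustration of the idea (not Neuraxle's caching wrappers), the sketch below memoizes a transform function by hashing its input and keeps at most max_items results. The class name, the in-memory dict store, and the eviction rule are all hypothetical choices.

```python
import hashlib
import pickle

class CachedTransform:
    """Hypothetical wrapper: cache transform results keyed by a hash of the input."""

    def __init__(self, transform_fn, max_items=128):
        self.transform_fn = transform_fn
        self.max_items = max_items
        self.cache = {}   # an in-memory dict stands in for any database
        self.calls = 0    # how many times the real transform actually ran

    def transform(self, data_inputs):
        key = hashlib.sha256(pickle.dumps(data_inputs)).hexdigest()
        if key not in self.cache:
            if len(self.cache) >= self.max_items:
                # Evict the oldest entry (dicts preserve insertion order).
                self.cache.pop(next(iter(self.cache)))
            self.calls += 1
            self.cache[key] = self.transform_fn(data_inputs)
        return self.cache[key]

cached = CachedTransform(lambda xs: [x * 2 for x in xs], max_items=2)
first = cached.transform([1, 2, 3])    # computed
second = cached.transform([1, 2, 3])   # served from the cache
print(first, second, cached.calls)
```

Swapping the dict for a disk or database store changes where results live, but not the lookup-then-compute flow.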


Checkpoints

Checkpoints combine pipeline serialization and data transformation caching at once, and automatically, just by using a special pipeline type and defining where to checkpoint within the pipeline.