# Deploying a TensorFlow graph via XLA AOT compilation
Many machine learning models are deployed as cloud services, where you can accommodate a full-blown runtime, but managing servers and requiring internet connectivity for your app is a hassle. Instead, you can use tfcompile, XLA's ahead-of-time compilation tool, to compile a TensorFlow graph into executable machine code and deploy that as a microservice or a native application.

# XLA
[XLA](https://www.tensorflow.org/performance/xla/) is a compiler for TensorFlow graphs.

- TensorFlow's graph abstraction incurs runtime overhead.
- XLA reduces that overhead, for example by fusing operations, so we can afford to write models from high-level ops without relying on custom op kernels for speed.
- The compiler can be used for just-in-time graph optimization during model training, but here we'll focus on ahead-of-time (AOT) compilation for model deployment.
- The implementation is still maturing: XLA was released in March last year, and several commits land every day.

![image.png](https://2.bp.blogspot.com/-yhjY3pc6oow/WLRn2z4mPBI/AAAAAAAACcU/t_EAR6QMwQQkTBPftJQEonaB2DMbRXmXwCLcB/s640/Screen%2BShot%2B2017-02-27%2Bat%2B9.54.12%2BAM.png)

![](https://www.tensorflow.org/images/how-does-xla-work.png)

# Steps for ahead-of-time compiling a graph with XLA
We'll use the command-line tool tfcompile via Bazel.
1. Configure the subgraph to compile.
1. Use the tf_library build macro to compile the subgraph.
1. Write code to invoke the subgraph.
1. Create the final binary.

## Step 0: Model
Before compiling a graph, we need to build one. Let's keep it simple and load a pretrained image classifier.

In [3]:
# This cell can be safely removed and doesn't need to be run.
%env CUDA_VISIBLE_DEVICES=''
import tensorflow as tf

env: CUDA_VISIBLE_DEVICES=''


In [4]:
import tensorflow as tf

tf.keras.backend.set_learning_phase(False)
model = tf.keras.applications.ResNet50()
model.summary(80)

________________________________________________________________________________
Layer (type)              Output Shape      Param #  Connected to               
input_1 (InputLayer)      (None, 224, 224,  0                                   
________________________________________________________________________________
conv1 (Conv2D)            (None, 112, 112,  9472     input_1[0][0]              
________________________________________________________________________________
bn_conv1 (BatchNormalizat (None, 112, 112,  256      conv1[0][0]                
________________________________________________________________________________
activation_1 (Activation) (None, 112, 112,  0        bn_conv1[0][0]             
________________________________________________________________________________
max_pooling2d_1 (MaxPooli (None, 55, 55, 64 0        activation_1[0][0]         
________________________________________________________________________________
res2a_branch2a (Conv2D)   (N

## Step 0.5: Download tfcompile
XLA is still maturing, so for now we have to check out the development version of TensorFlow. The prerequisites are git, the build tool [Bazel](https://docs.bazel.build) and the [Protocol Buffers](https://developers.google.com/protocol-buffers) compiler. I'm also assuming we're running tf-nightly, which can be installed via pip. Besides cloning and configuring the repo, the cell below generates Python bindings for `tf2xla.proto`, the schema of tfcompile's config file.
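Before cloning, it can save a failed build to confirm that the command-line prerequisites are on `PATH`. A minimal standard-library sketch; `missing_tools` is a hypothetical helper, and it only checks that the executables exist, not their versions:

```python
import shutil

# Command-line tools this walkthrough shells out to.
REQUIRED_TOOLS = ("git", "bazel", "protoc")


def missing_tools(tools=REQUIRED_TOOLS):
    """Return the subset of `tools` that cannot be found on PATH."""
    return [tool for tool in tools if shutil.which(tool) is None]


missing = missing_tools()
if missing:
    print("Missing prerequisites:", ", ".join(missing))
else:
    print("git, bazel and protoc are all on PATH")
```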

In [5]:
%rm -rf /tmp/tensorflow

In [6]:
%cd /tmp
!git clone --depth=1 --single-branch https://github.com/tensorflow/tensorflow
%cd tensorflow
!yes "" | ./configure
!protoc tensorflow/compiler/tf2xla/tf2xla.proto --python_out=.
!cp tensorflow/compiler/tf2xla/tf2xla_pb2.py .

/tmp
Cloning into 'tensorflow'...
remote: Counting objects: 10580, done.
remote: Compressing objects: 100% (8825/8825), done.
remote: Total 10580 (delta 3329), reused 3594 (delta 1486), pack-reused 0
Receiving objects: 100% (10580/10580), 21.65 MiB | 4.71 MiB/s, done.
Resolving deltas: 100% (3329/3329), done.
/tmp/tensorflow
You have bazel 0.8.1 installed.
Please specify the location of python. [Default is /home/carl/anaconda3/bin/python]: 

Found possible Python library paths:
  /home/carl/anaconda3/lib/python3.6/site-packages
Please input the desired Python library path to use.  Default is [/home/carl/anaconda3/lib/python3.6/site-packages]
Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: jemalloc as malloc support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]:

## Step 1: Configure the subgraph to compile.

### List feeds and fetches
tfcompile needs static input shapes, so we have to pick a batch size for our image classifier.

In [7]:
import tf2xla_pb2

config = tf2xla_pb2.Config()

batch_size = 1

for x in model.inputs:
    x.set_shape([batch_size] + list(x.shape)[1:])
    feed = config.feed.add()
    feed.id.node_name = x.op.name
    feed.shape.MergeFrom(x.shape.as_proto())

for x in model.outputs:
    fetch = config.fetch.add()
    fetch.id.node_name = x.op.name

with open('graph.config.pbtxt', 'w') as f:
    f.write(str(config))

In [8]:
cat graph.config.pbtxt

feed {
  id {
    node_name: "input_1"
  }
  shape {
    dim {
      size: 1
    }
    dim {
      size: 224
    }
    dim {
      size: 224
    }
    dim {
      size: 3
    }
  }
}
fetch {
  id {
    node_name: "fc1000/Softmax"
  }
}
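The config is just a protobuf text file, so if you'd rather not generate `tf2xla_pb2` at all, you can write the same feeds and fetches by hand. A minimal sketch under that assumption; `make_config_pbtxt` is a hypothetical helper that only covers the `feed`/`fetch` fields used above (protobuf text format ignores whitespace, so one-line `id { ... }` entries are equivalent to the nested output above):

```python
def make_config_pbtxt(feeds, fetches):
    """Render a tf2xla Config in protobuf text format.

    feeds:   list of (node_name, shape) pairs, shape being a tuple of ints
    fetches: list of node names
    """
    lines = []
    for node_name, shape in feeds:
        lines.append("feed {")
        lines.append(f'  id {{ node_name: "{node_name}" }}')
        lines.append("  shape {")
        for size in shape:  # one dim entry per static dimension
            lines.append(f"    dim {{ size: {size} }}")
        lines.append("  }")
        lines.append("}")
    for node_name in fetches:
        lines.append("fetch {")
        lines.append(f'  id {{ node_name: "{node_name}" }}')
        lines.append("}")
    return "\n".join(lines) + "\n"


text = make_config_pbtxt([("input_1", (1, 224, 224, 3))], ["fc1000/Softmax"])
print(text)
```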


### Freeze graph
The graph contains variables, which tfcompile needs as constants. tfcompile's `tf_library` build macro can handle this for you (via [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) if you provide a weights checkpoint along with the graph definition, but since we already have everything loaded, we'll convert the variables to constants right away.

In [10]:
session = tf.keras.backend.get_session()
output_node_names = [node.op.name for node in model.outputs]
graphdef = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names)
tf.train.write_graph(graphdef, '.', 'graph.pb', as_text=False)

INFO:tensorflow:Froze 320 variables.
Converted 320 variables to const ops.


'./graph.pb'

## Step 2: Use the tf_library build macro to compile the subgraph.

In [11]:
%%writefile BUILD

load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')

tf_library(
    name = 'graph',
    config = 'graph.config.pbtxt',
    cpp_class = 'Graph',
    graph = 'graph.pb',
)

Overwriting BUILD


In [12]:
!bazel build --show_progress_rate_limit=600 @org_tensorflow//:graph

.......
Loading: 
Loading: 0 packages loaded
Analyzing: target @org_tensorflow//:graph (68 packages loaded)
INFO: Analysed target @org_tensorflow//:graph (74 packages loaded).
Building: no action running
INFO: Found 1 target...
Building: no action running
[0 / 6] BazelWorkspaceStatusAction stable-status.txt
INFO: From Executing genrule @org_tensorflow//tensorflow/core:version_info_gen [for host]:
[1,674 / 3,309] @org_tensorflow//tensorflow/core:version_info_gen; 0s local
fatal: No names found, cannot describe anything.
[1,674 / 3,309] @org_tensorflow//tensorflow/core:version_info_gen; 0s local
INFO: From Executing genrule @org_tensorflow//:gen_graph:
[3,332 / 3,336] Executing genrule @org_tensorflow//:gen_graph; 47s local
2018-01-11 15:27:20.408071: I external/org_tensorflow/tensorflow/core/platform/s3/aws_logging.c

In [13]:
cat bazel-genfiles/graph.h

// Generated by tfcompile, the TensorFlow graph compiler.  DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph.  An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)


#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void ____graph(
    void* result, const xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);


// Graph represents a computation previously specified in a
// T

## Step 3: Write code to invoke the subgraph.

In [14]:
%%writefile graph.cc

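#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <algorithm>  // std::copy
#include <thread>     // std::thread::hardware_concurrency

#include "graph.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Exported with C linkage so it can be loaded from Python via ctypes.
// Copies the input into the compiled graph, runs it, and copies the result out.
extern "C" int run(float *input, float *output, int input_size, int output_size) {
  // Eigen thread pool the compiled code runs its parallel work on.
  Eigen::ThreadPool tp(std::thread::hardware_concurrency());
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
  Graph graph;  // class generated by tf_library (cpp_class = 'Graph')
  graph.set_thread_pool(&device);

  std::copy(input, input + input_size, graph.arg0_data());
  auto ok = graph.Run();
  if (not ok) return -1;
  std::copy(graph.result0_data(), graph.result0_data() + output_size, output);
  return 0;
}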
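`run` copies `input_size` floats into `arg0_data()` and `output_size` floats out of `result0_data()` without checking them against the compiled shapes, so the caller has to pass exactly the element counts fixed by `graph.config.pbtxt`. For batch size 1 and ResNet50's 1000-way softmax the arithmetic works out as follows (`flat_size` is a hypothetical helper):

```python
from functools import reduce
from operator import mul


def flat_size(shape):
    """Number of elements in a dense tensor with the given static shape."""
    return reduce(mul, shape, 1)


# Shapes fixed at compile time: batch_size = 1, 224x224 RGB input, 1000 classes.
input_size = flat_size((1, 224, 224, 3))
output_size = flat_size((1, 1000))
print(input_size, output_size)                    # 150528 1000
print(input_size * 4, "bytes of float32 input")   # 602112 bytes of float32 input
```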

Writing graph.cc


## Step 4: Create the final binary.
Instead of calling `gcc` directly, we'll add a `cc_binary` rule, since Bazel is already required to build tfcompile. In fact, we could have written a single BUILD file with both rules right after cloning the TensorFlow repo.

In [15]:
%%writefile -a BUILD

cc_binary(
    name = "libmodel.so",
    srcs = ["graph.cc"],
    deps = [":graph", "//third_party/eigen3"],
    linkopts = ["-lpthread"],
    linkshared = 1,
    copts = ["-fPIC"],
)

Appending to BUILD


In [16]:
!bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so

Loading: 
Loading: 0 packages loaded
Analyzing: target @org_tensorflow//:libmodel.so (2 packages loaded)
INFO: Analysed target @org_tensorflow//:libmodel.so (2 packages loaded).
Building: no action running
INFO: Found 1 target...
Building: no action running
[0 / 5] BazelWorkspaceStatusAction stable-status.txt
Target @org_tensorflow//:libmodel.so up-to-date:
[632 / 632] no action running
  bazel-bin/external/org_tensorflow/libmodel.so
[632 / 632] no action running
INFO: Elapsed time: 1.852s, Critical Path: 0.56s
[632 / 632] no action running
INFO: Build completed successfully, 1 total action

In [19]:
import ctypes
import numpy as np

# Load bazel-bin/external/org_tensorflow/libmodel.so and declare run()'s C signature.
libmodel = np.ctypeslib.load_library('libmodel', 'bazel-bin/external/org_tensorflow')
libmodel.run.argtypes = [
    np.ctypeslib.ndpointer(np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a')),
    np.ctypeslib.ndpointer(np.float32, ndim=2, shape=(1, 1000), flags=('c', 'a', 'w')),
    ctypes.c_int,
    ctypes.c_int]
libmodel.run.restype = ctypes.c_int


def predict(x):
    # run() expects a C-contiguous, aligned float32 buffer.
    x = np.require(x, np.float32, ('c', 'a'))
    y = np.zeros((1, 1000), dtype=np.float32)
    if libmodel.run(x, y, x.size, y.size) != 0:
        raise RuntimeError('compiled graph failed to run')
    return y
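The `argtypes` are what keep `run` from reading or writing past its buffers: ctypes passes every argument through its type's `from_param`, and an `ndpointer` type raises `TypeError` on a wrong dtype, rank, shape or memory layout. A small numpy-only check of that behaviour, reusing the input type declared above (`rejection_reason` is a hypothetical helper):

```python
import numpy as np

# Same input argument type as registered on libmodel.run.
InputPtr = np.ctypeslib.ndpointer(
    np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a'))


def rejection_reason(arr):
    """Return the TypeError message ctypes would raise for `arr`, or None if accepted."""
    try:
        InputPtr.from_param(arr)
    except TypeError as err:
        return str(err)
    return None


ok = np.zeros((1, 224, 224, 3), dtype=np.float32)
wrong_dtype = np.zeros((1, 224, 224, 3), dtype=np.float64)
wrong_batch = np.zeros((2, 224, 224, 3), dtype=np.float32)
# Right shape, but a transposed NCHW view is not C-contiguous.
non_contiguous = np.zeros((1, 3, 224, 224), dtype=np.float32).transpose(0, 2, 3, 1)

for name, arr in [("ok", ok), ("wrong_dtype", wrong_dtype),
                  ("wrong_batch", wrong_batch), ("non_contiguous", non_contiguous)]:
    print(name, "->", rejection_reason(arr) or "accepted")
```

This is also why `predict` runs its input through `np.require`: dtype and layout are fixed up front, so only a wrong shape (for example a batch larger than 1) still gets rejected.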

In [20]:
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions

image_path = input()

x = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))
x = x[None, ...]
x = preprocess_input(x)
y = predict(x)
decode_predictions(y)[0]

[('n02110806', 'basenji', 0.60816735),
 ('n02441942', 'weasel', 0.10849755),
 ('n02091244', 'Ibizan_hound', 0.081580825),
 ('n02124075', 'Egyptian_cat', 0.044705715),
 ('n02123597', 'Siamese_cat', 0.025189402)]
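`decode_predictions` ranks the 1000 softmax scores and maps the best indices to ImageNet WordNet IDs and class names. The ranking half is plain numpy; a sketch of it, with the label lookup left out and a toy score vector standing in for `predict(x)` (`top_k` is a hypothetical helper):

```python
import numpy as np


def top_k(scores, k=5):
    """(index, score) pairs for the k largest entries of a 1-D score vector, best first."""
    idx = np.argsort(scores)[-k:][::-1]
    return [(int(i), float(scores[i])) for i in idx]


# Toy single-row "softmax" output standing in for predict(x).
toy_output = np.array([[0.05, 0.60, 0.10, 0.20, 0.05]], dtype=np.float32)
print(top_k(toy_output[0], k=3))
```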

In [36]:
%timeit model.predict(x)
%timeit predict(x)
np.testing.assert_allclose(model.predict(x), predict(x), atol=1e-5)

150 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
191 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
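`%timeit` only exists inside IPython. To benchmark the shared library from a plain script, for instance on the deployment target, the standard library's `timeit.repeat` can produce a similar "mean ± std. dev. of 7 runs, 10 loops each" summary. A minimal sketch; `benchmark` is a hypothetical helper and, unlike `%timeit`, it uses a fixed loop count:

```python
import statistics
import timeit


def benchmark(fn, number=10, repeat=7):
    """Time fn(); return (mean, std. dev.) of the per-call time in seconds.

    Runs `repeat` rounds of `number` calls each, like %timeit's default report.
    """
    run_totals = timeit.repeat(fn, number=number, repeat=repeat)
    per_call = [total / number for total in run_totals]
    return statistics.mean(per_call), statistics.stdev(per_call)


mean, std = benchmark(lambda: sum(range(10_000)))
print(f"{mean * 1e3:.3f} ms ± {std * 1e6:.1f} µs per loop")
```

In the notebook, the equivalent of the second measurement would be `benchmark(lambda: predict(x))`.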


In [37]:
%%timeit
model = tf.keras.applications.ResNet50()
model.predict(x)

2.96 s ± 456 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# References
- https://www.tensorflow.org/performance/xla/tfcompile
- https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html
- https://youtu.be/kAOanJczHA0
- https://youtu.be/2IOPpyyuLkc