tfcompile
Deploying a TensorFlow graph via XLA AOT compilation¶
Many machine learning models are deployed as cloud services, where a full-blown runtime is easy to accommodate, but managing servers and requiring internet connectivity from your app is a hassle. Instead, you can use tfcompile (an XLA command-line tool) to compile a TensorFlow graph into executable machine code, and then deploy that as a microservice or native application.
XLA¶
XLA is a compiler for TensorFlow graphs.
- TensorFlow's graph abstraction incurs overhead.
- XLA removes much of that overhead, so we can afford to write high-level code without relying on the existence of hand-written custom op kernels.
- The compiler can also be used for graph optimization during model training, but we'll focus on ahead-of-time (AOT) compilation for model deployment.
- The implementation is still maturing: XLA was released in March 2017 and sees several commits per day.
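To build intuition for the overhead XLA eliminates, here is a toy, pure-Python analogy (no TensorFlow involved): evaluating `a * b + c` as two separate elementwise passes materializes an intermediate buffer, much like dispatching one kernel per graph node, whereas a single fused pass does not. Op fusion of this kind is one of the optimizations XLA applies to the graph.

```python
a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
c = [7.0, 8.0, 9.0]

# Unfused: each "op" is a separate pass with an intermediate buffer,
# analogous to executing one TensorFlow kernel per graph node.
tmp = [x * y for x, y in zip(a, b)]
unfused = [t + z for t, z in zip(tmp, c)]

# Fused: a single pass with no intermediate buffer,
# analogous to the fused kernel XLA would emit.
fused = [x * y + z for x, y, z in zip(a, b, c)]

assert fused == unfused
```

The two results are identical; the fused version simply skips the intermediate allocation and the extra pass over the data.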


Steps for compiling a graph ahead-of-time with XLA¶
We'll use the command-line tool tfcompile via Bazel.
- Configure the subgraph to compile.
- Use the tf_library build macro to compile the subgraph.
- Write code to invoke the subgraph.
- Create the final binary.
Step 0: Model¶
Before we can compile a graph, we need a graph. Let's keep it simple and just load a pretrained image classifier.
# This cell can be safely removed and doesn't need to be run.
%env CUDA_VISIBLE_DEVICES=''
env: CUDA_VISIBLE_DEVICES=''
import tensorflow as tf
tf.keras.backend.set_learning_phase(False)
model = tf.keras.applications.ResNet50()
model.summary(80)
________________________________________________________________________________
Layer (type)                 Output Shape          Param #   Connected to
================================================================================
input_1 (InputLayer)         (None, 224, 224, 3)   0
________________________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 64)  9472      input_1[0][0]
________________________________________________________________________________
bn_conv1 (BatchNormalizatio  (None, 112, 112, 64)  256       conv1[0][0]
________________________________________________________________________________
...                          (intermediate ResNet50 layers omitted)
________________________________________________________________________________
avg_pool (AveragePooling2D)  (None, 1, 1, 2048)    0         activation_49[0][0]
________________________________________________________________________________
flatten_1 (Flatten)          (None, 2048)          0         avg_pool[0][0]
________________________________________________________________________________
fc1000 (Dense)               (None, 1000)          2049000   flatten_1[0][0]
================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
________________________________________________________________________________
Step 0.5: Download tfcompile¶
XLA is still maturing, so for now we have to check out the development version of TensorFlow. System prerequisites are git, the build tool Bazel, and the Protocol Buffers compiler. I'm also assuming we're running tf-nightly, which can be installed via pip.
%rm -rf /tmp/tensorflow
%cd /tmp
!git clone --depth=1 --single-branch https://github.com/tensorflow/tensorflow
%cd tensorflow
!yes "" | ./configure
!protoc tensorflow/compiler/tf2xla/tf2xla.proto --python_out=.
!cp tensorflow/compiler/tf2xla/tf2xla_pb2.py .
/tmp
Cloning into 'tensorflow'...
remote: Counting objects: 10580, done.
Receiving objects: 100% (10580/10580), 21.65 MiB | 4.71 MiB/s, done.
Resolving deltas: 100% (3329/3329), done.
/tmp/tensorflow
You have bazel 0.8.1 installed.
Please specify the location of python. [Default is /home/carl/anaconda3/bin/python]:
Do you wish to build TensorFlow with XLA JIT support? [y/N]:
No XLA JIT support will be enabled for TensorFlow.
Do you wish to build TensorFlow with CUDA support? [y/N]:
No CUDA support will be enabled for TensorFlow.
(remaining prompts answered with their defaults)
Configuration finished
yes: standard output: Broken pipe
Step 1: Configure the subgraph to compile.¶
List feeds and fetches¶
tfcompile needs static input shapes, so we have to pick a fixed batch size for our image classifier.
import tf2xla_pb2

config = tf2xla_pb2.Config()
batch_size = 1

for x in model.inputs:
    x.set_shape([batch_size] + list(x.shape)[1:])
    feed = config.feed.add()
    feed.id.node_name = x.op.name
    feed.shape.MergeFrom(x.shape.as_proto())

for x in model.outputs:
    fetch = config.fetch.add()
    fetch.id.node_name = x.op.name

with open('graph.config.pbtxt', 'w') as f:
    f.write(str(config))
!cat graph.config.pbtxt
feed {
  id {
    node_name: "input_1"
  }
  shape {
    dim {
      size: 1
    }
    dim {
      size: 224
    }
    dim {
      size: 224
    }
    dim {
      size: 3
    }
  }
}
fetch {
  id {
    node_name: "fc1000/Softmax"
  }
}
Freeze graph¶
The graph contains mutable nodes (the model's variables) that have to be turned into constants. It's possible to let tfcompile handle this for you (via freeze_graph.py) by providing a weights checkpoint along with the graph definition, but since we already have everything loaded, we'll convert them to constants right away.
session = tf.keras.backend.get_session()
output_node_names = [node.op.name for node in model.outputs]
graphdef = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names)
tf.train.write_graph(graphdef, '.', 'graph.pb', as_text=False)
INFO:tensorflow:Froze 320 variables. Converted 320 variables to const ops.
'./graph.pb'
Step 2: Use the tf_library build macro to compile the subgraph.¶
%%writefile BUILD
load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')

tf_library(
    name = 'graph',
    config = 'graph.config.pbtxt',
    cpp_class = 'Graph',
    graph = 'graph.pb',
)
Overwriting BUILD
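Because tfcompile bakes the input shapes into the binary, supporting several batch sizes means one tf_library rule (and one matching config) per size. A hypothetical helper that generates such rules — the `graph_bN` naming scheme is our own convention, not anything tfcompile requires:

```python
# Template for one tf_library rule per batch size; the graph_bN /
# GraphBN names are our own convention for this sketch.
RULE = '''tf_library(
    name = 'graph_b%(n)d',
    config = 'graph_b%(n)d.config.pbtxt',
    cpp_class = 'GraphB%(n)d',
    graph = 'graph.pb',
)
'''

def build_rules(batch_sizes):
    # Emit one rule per requested batch size, ready to append to BUILD.
    return '\n'.join(RULE % {'n': n} for n in batch_sizes)

print(build_rules([1, 8]))
```

Each generated rule would also need its own config file with the corresponding batch dimension in the feed shape.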
!bazel build --show_progress_rate_limit=600 @org_tensorflow//:graph
.......
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
Analyzing: target @org_tensorflow//:graph (68 packages loaded)
INFO: Analysed target @org_tensorflow//:graph (74 packages loaded).
INFO: Found 1 target...
[0 / 6] BazelWorkspaceStatusAction stable-status.txt
INFO: From Executing genrule @org_tensorflow//tensorflow/core:version_info_gen [for host]:
fatal: No names found, cannot describe anything.
[1,674 / 3,309] @org_tensorflow//tensorflow/core:version_info_gen; 0s local
INFO: From Executing genrule @org_tensorflow//:gen_graph:
2018-01-11 15:27:20.408071: I external/org_tensorflow/tensorflow/core/platform/s3/aws_logging.cc:53] Initializing Curl library
2018-01-11 15:27:20.514752: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[3,332 / 3,336] Executing genrule @org_tensorflow//:gen_graph; 47s local
Target @org_tensorflow//:graph up-to-date:
  bazel-bin/external/org_tensorflow/libgraph.a
  bazel-bin/external/org_tensorflow/libgraph.pic.a
  bazel-bin/external/org_tensorflow/libgraph.so
INFO: Elapsed time: 57.837s, Critical Path: 50.33s
INFO: Build completed successfully, 3 total actions
!cat bazel-genfiles/graph.h
// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph. An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)

#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void ____graph(
    void* result, const xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);

// Graph represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   Graph computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
//   o Calls to non-const methods require exclusive access to the object.
//   o Concurrent calls to const methods are OK, if those calls are made while
//     it is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   (arg0: f32[1,224,224,3]) -> (f32[1,1000])
//
// Memory stats:
//   arg bytes total:    602112
//   arg bytes aligned:  602112
//   temp bytes total:   17815208
//   temp bytes aligned: 17815232
class Graph : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = 1;

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const intptr_t* ArgSizes() {
    static constexpr intptr_t kArgSizes[kNumArgs] = {602112};
    return kArgSizes;
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
          new XlaCompiledCpuFunction::StaticData;
      data->raw_function = ____graph;
      data->arg_sizes = ArgSizes();
      data->num_args = kNumArgs;
      data->temp_sizes = TempSizes();
      data->num_temps = kNumTemps;
      data->result_index = kResultIndex;
      data->arg_names = StaticArgNames();
      data->result_names = StaticResultNames();
      data->program_shape = StaticProgramShape();
      return data;
    }();
    return *kStaticData;
  }

  Graph(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  Graph(const Graph&) = delete;
  Graph& operator=(const Graph&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  //   void set_argN_data(void* data)
  //     Sets the buffer of type T for positional argument N. May be called in
  //     any AllocMode. Must be called before Run to have an affect. Must be
  //     called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //     argument, to set the argument buffers.
  //
  //   T* argN_data()
  //     Returns the buffer of type T for positional argument N.
  //
  //   T& argN(...dim indices...)
  //     Returns a reference to the value of type T for positional argument N,
  //     with dim indices specifying which value. No bounds checking is performed
  //     on dim indices.
  void set_arg0_data(void* data) {
    set_arg_data(0, data);
  }
  float* arg0_data() {
    return static_cast<float*>(arg_data(0));
  }
  float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) {
    return (*static_cast<float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }
  const float* arg0_data() const {
    return static_cast<const float*>(arg_data(0));
  }
  const float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) const {
    return (*static_cast<const float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  //   T* resultN_data()
  //     Returns the buffer of type T for positional result N.
  //
  //   T& resultN(...dim indices...)
  //     Returns a reference to the value of type T for positional result N,
  //     with dim indices specifying which value. No bounds checking is performed
  //     on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.
  float* result0_data() {
    return static_cast<float*>(result_data(0));
  }
  float& result0(size_t dim0, size_t dim1) {
    return (*static_cast<float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }
  const float* result0_data() const {
    return static_cast<const float*>(result_data(0));
  }
  const float& result0(size_t dim0, size_t dim1) const {
    return (*static_cast<const float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }

 private:
  // Number of result and temporary buffers for the compiled computation.
  static constexpr size_t kNumTemps = 10;

  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = 2;

  // Byte size of each result / temporary buffer. There are kNumTemps entries.
  static const intptr_t* TempSizes() {
    static constexpr intptr_t kTempSizes[kNumTemps] = {-1, 4000, 8, -1, -1, -1, -1, -1, -1, 17811200};
    return kTempSizes;
  }

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {
    return nullptr;
  }

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {
    return nullptr;
  }

  // Shape of the args and results.
  static const xla::ProgramShape* StaticProgramShape() {
    return nullptr;
  }
};

#endif  // TFCOMPILE_GENERATED_____graph_H_

// clang-format on
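The memory stats in the header follow directly from the shapes we fed to tfcompile: the single argument buffer holds 1×224×224×3 float32 values, and the 4000-byte entry in kTempSizes is the (1, 1000) float32 result. A quick sanity check:

```python
from functools import reduce
from operator import mul

def nbytes(shape, itemsize=4):
    # Byte size of a dense buffer of the given shape; float32 = 4 bytes.
    return reduce(mul, shape) * itemsize

# Matches "arg bytes total: 602112" and kArgSizes in graph.h.
assert nbytes((1, 224, 224, 3)) == 602112
# Matches the 4000-byte result buffer in kTempSizes.
assert nbytes((1, 1000)) == 4000
```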
Step 3: Write code to invoke the subgraph.¶
%%writefile graph.cc
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include "graph.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

extern "C" int run(float *input, float *output, int input_size, int output_size) {
  Eigen::ThreadPool tp(std::thread::hardware_concurrency());
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  Graph graph;
  graph.set_thread_pool(&device);

  std::copy(input, input + input_size, graph.arg0_data());
  auto ok = graph.Run();
  if (not ok) return -1;
  std::copy(graph.result0_data(), graph.result0_data() + output_size, output);
  return 0;
}
Writing graph.cc
Step 4: Create the final binary.¶
Since Bazel is already required for building the tfcompile tool, we'll create a cc_binary rule instead of invoking gcc directly. In fact, we could have put everything into one big BUILD file right after cloning the TensorFlow repo.
%%writefile -a BUILD
cc_binary(
    name = "libmodel.so",
    srcs = ["graph.cc"],
    deps = [":graph", "//third_party/eigen3"],
    linkopts = ["-lpthread"],
    linkshared = 1,
    copts = ["-fPIC"],
)
Appending to BUILD
!bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
Analyzing: target @org_tensorflow//:libmodel.so (2 packages loaded)
INFO: Analysed target @org_tensorflow//:libmodel.so (2 packages loaded).
INFO: Found 1 target...
[0 / 5] BazelWorkspaceStatusAction stable-status.txt
Target @org_tensorflow//:libmodel.so up-to-date:
  bazel-bin/external/org_tensorflow/libmodel.so
INFO: Elapsed time: 1.852s, Critical Path: 0.56s
INFO: Build completed successfully, 1 total action
import numpy as np

libmodel = np.ctypeslib.load_library('libmodel', 'bazel-bin/external/org_tensorflow')

libmodel.run.argtypes = [
    np.ctypeslib.ndpointer(np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a')),
    np.ctypeslib.ndpointer(np.float32, ndim=2, shape=(1, 1000), flags=('c', 'a', 'w')),
    np.ctypeslib.ctypes.c_int,
    np.ctypeslib.ctypes.c_int]

def predict(x):
    x = np.require(x, np.float32, ('c', 'a'))
    y = np.require(np.zeros((1, 1000)), np.float32, ('c', 'a', 'w'))
    libmodel.run(x, y, x.size, y.size)
    return y
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
image_path = input()
x = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))
x = x[None, ...]
x = preprocess_input(x)
y = predict(x)
decode_predictions(y)[0]
[('n02110806', 'basenji', 0.60816735),
('n02441942', 'weasel', 0.10849755),
('n02091244', 'Ibizan_hound', 0.081580825),
('n02124075', 'Egyptian_cat', 0.044705715),
('n02123597', 'Siamese_cat', 0.025189402)]
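Note that predict expects inputs preprocessed the same way the model was trained. For ResNet50, Keras's preprocess_input (assuming its default 'caffe' mode) flips RGB to BGR and subtracts the ImageNet channel means; a NumPy-only sketch of that step, handy if Keras isn't available at deployment time:

```python
import numpy as np

def preprocess(x):
    # RGB -> BGR, then subtract the ImageNet channel means
    # (caffe-style preprocessing, as used by Keras ResNet50 by default).
    x = np.asarray(x, dtype=np.float32)[..., ::-1].copy()
    x -= np.array([103.939, 116.779, 123.68], dtype=np.float32)
    return x
```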
%timeit model.predict(x)
%timeit predict(x)
np.testing.assert_allclose(model.predict(x), predict(x), atol=1e-5)
150 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
191 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
model = tf.keras.applications.ResNet50()
model.predict(x)
2.96 s ± 456 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)