
tfcompile

Deploying a TensorFlow graph via XLA AOT compilation

Many machine learning models are deployed as cloud services, where a full-blown runtime is easy to accommodate, but managing servers and requiring internet connectivity for your app is a hassle. Instead, you can use tfcompile (an XLA CLI tool) to compile a TensorFlow graph into executable machine code, and then deploy that as a microservice or native application.

XLA

XLA is a compiler for TensorFlow graphs.

  • TensorFlow's graph abstraction incurs runtime overhead.
  • XLA combats this, so we can afford to write high-level code without relying on the existence of custom op kernels.
  • The compiler can be used for graph optimization during model training, but we'll focus on ahead-of-time (AOT) compilation for model deployment.
  • The implementation is still maturing: XLA was released in March last year, and there are several commits per day.


Steps for ahead-of-time compiling a graph with XLA

We'll use the command-line tool tfcompile via Bazel.

  1. Configure the subgraph to compile.
  2. Use the tf_library build macro to compile the subgraph.
  3. Write code to invoke the subgraph.
  4. Create the final binary.
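
As a preview of step 2, a `tf_library` target in a Bazel BUILD file might look roughly like the sketch below. The target name, graph file, config file, and C++ class name are hypothetical placeholders, not values from this post:

```python
# BUILD file sketch (Bazel/Starlark); all names are hypothetical placeholders.
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

tf_library(
    name = "resnet50_graph",        # Bazel target to build
    graph = "frozen_graph.pb",      # frozen TensorFlow GraphDef
    config = "graph_config.pbtxt",  # feeds/fetches from step 1
    cpp_class = "ResNet50Graph",    # name of the generated C++ class
)
```

Building this target runs tfcompile and produces a header plus an object file that the final binary (step 4) links against.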

Step 0: Model

Before we can compile a graph we need a graph to compile. Let's keep it simple by loading a pretrained image classifier.

In [3]:
# This cell can be safely removed and doesn't need to be run.
%env CUDA_VISIBLE_DEVICES=''
import tensorflow as tf
env: CUDA_VISIBLE_DEVICES=''
In [4]:
import tensorflow as tf

tf.keras.backend.set_learning_phase(False)
model = tf.keras.applications.ResNet50()
model.summary(80)
________________________________________________________________________________
Layer (type)              Output Shape      Param #  Connected to               
================================================================================
input_1 (InputLayer)      (None, 224, 224,  0                                   
________________________________________________________________________________
conv1 (Conv2D)            (None, 112, 112,  9472     input_1[0][0]              
________________________________________________________________________________
bn_conv1 (BatchNormalizat (None, 112, 112,  256      conv1[0][0]                
________________________________________________________________________________
activation_1 (Activation) (None, 112, 112,  0        bn_conv1[0][0]             
________________________________________________________________________________
max_pooling2d_1 (MaxPooli (None, 55, 55, 64 0        activation_1[0][0]         
________________________________________________________________________________
res2a_branch2a (Conv2D)   (None, 55, 55, 64 4160     max_pooling2d_1[0][0]      
________________________________________________________________________________
bn2a_branch2a (BatchNorma (None, 55, 55, 64 256      res2a_branch2a[0][0]       
________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64 0        bn2a_branch2a[0][0]        
________________________________________________________________________________
res2a_branch2b (Conv2D)   (None, 55, 55, 64 36928    activation_2[0][0]         
________________________________________________________________________________
bn2a_branch2b (BatchNorma (None, 55, 55, 64 256      res2a_branch2b[0][0]       
________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64 0        bn2a_branch2b[0][0]        
________________________________________________________________________________
res2a_branch2c (Conv2D)   (None, 55, 55, 25 16640    activation_3[0][0]         
________________________________________________________________________________
res2a_branch1 (Conv2D)    (None, 55, 55, 25 16640    max_pooling2d_1[0][0]      
________________________________________________________________________________
bn2a_branch2c (BatchNorma (None, 55, 55, 25 1024     res2a_branch2c[0][0]       
________________________________________________________________________________
bn2a_branch1 (BatchNormal (None, 55, 55, 25 1024     res2a_branch1[0][0]        
________________________________________________________________________________
add_1 (Add)               (None, 55, 55, 25 0        bn2a_branch2c[0][0]        
                                                     bn2a_branch1[0][0]         
________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 25 0        add_1[0][0]                
________________________________________________________________________________
res2b_branch2a (Conv2D)   (None, 55, 55, 64 16448    activation_4[0][0]         
________________________________________________________________________________
bn2b_branch2a (BatchNorma (None, 55, 55, 64 256      res2b_branch2a[0][0]       
________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64 0        bn2b_branch2a[0][0]        
________________________________________________________________________________
res2b_branch2b (Conv2D)   (None, 55, 55, 64 36928    activation_5[0][0]         
________________________________________________________________________________
bn2b_branch2b (BatchNorma (None, 55, 55, 64 256      res2b_branch2b[0][0]       
________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64 0        bn2b_branch2b[0][0]        
________________________________________________________________________________
res2b_branch2c (Conv2D)   (None, 55, 55, 25 16640    activation_6[0][0]         
________________________________________________________________________________
bn2b_branch2c (BatchNorma (None, 55, 55, 25 1024     res2b_branch2c[0][0]       
________________________________________________________________________________
add_2 (Add)               (None, 55, 55, 25 0        bn2b_branch2c[0][0]        
                                                     activation_4[0][0]         
________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 25 0        add_2[0][0]                
________________________________________________________________________________
res2c_branch2a (Conv2D)   (None, 55, 55, 64 16448    activation_7[0][0]         
________________________________________________________________________________
bn2c_branch2a (BatchNorma (None, 55, 55, 64 256      res2c_branch2a[0][0]       
________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64 0        bn2c_branch2a[0][0]        
________________________________________________________________________________
res2c_branch2b (Conv2D)   (None, 55, 55, 64 36928    activation_8[0][0]         
________________________________________________________________________________
bn2c_branch2b (BatchNorma (None, 55, 55, 64 256      res2c_branch2b[0][0]       
________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64 0        bn2c_branch2b[0][0]        
________________________________________________________________________________
res2c_branch2c (Conv2D)   (None, 55, 55, 25 16640    activation_9[0][0]         
________________________________________________________________________________
bn2c_branch2c (BatchNorma (None, 55, 55, 25 1024     res2c_branch2c[0][0]       
________________________________________________________________________________
add_3 (Add)               (None, 55, 55, 25 0        bn2c_branch2c[0][0]        
                                                     activation_7[0][0]         
________________________________________________________________________________
activation_10 (Activation (None, 55, 55, 25 0        add_3[0][0]                
________________________________________________________________________________
res3a_branch2a (Conv2D)   (None, 28, 28, 12 32896    activation_10[0][0]        
________________________________________________________________________________
bn3a_branch2a (BatchNorma (None, 28, 28, 12 512      res3a_branch2a[0][0]       
________________________________________________________________________________
activation_11 (Activation (None, 28, 28, 12 0        bn3a_branch2a[0][0]        
________________________________________________________________________________
res3a_branch2b (Conv2D)   (None, 28, 28, 12 147584   activation_11[0][0]        
________________________________________________________________________________
bn3a_branch2b (BatchNorma (None, 28, 28, 12 512      res3a_branch2b[0][0]       
________________________________________________________________________________
activation_12 (Activation (None, 28, 28, 12 0        bn3a_branch2b[0][0]        
________________________________________________________________________________
res3a_branch2c (Conv2D)   (None, 28, 28, 51 66048    activation_12[0][0]        
________________________________________________________________________________
res3a_branch1 (Conv2D)    (None, 28, 28, 51 131584   activation_10[0][0]        
________________________________________________________________________________
bn3a_branch2c (BatchNorma (None, 28, 28, 51 2048     res3a_branch2c[0][0]       
________________________________________________________________________________
bn3a_branch1 (BatchNormal (None, 28, 28, 51 2048     res3a_branch1[0][0]        
________________________________________________________________________________
add_4 (Add)               (None, 28, 28, 51 0        bn3a_branch2c[0][0]        
                                                     bn3a_branch1[0][0]         
________________________________________________________________________________
activation_13 (Activation (None, 28, 28, 51 0        add_4[0][0]                
________________________________________________________________________________
res3b_branch2a (Conv2D)   (None, 28, 28, 12 65664    activation_13[0][0]        
________________________________________________________________________________
bn3b_branch2a (BatchNorma (None, 28, 28, 12 512      res3b_branch2a[0][0]       
________________________________________________________________________________
activation_14 (Activation (None, 28, 28, 12 0        bn3b_branch2a[0][0]        
________________________________________________________________________________
res3b_branch2b (Conv2D)   (None, 28, 28, 12 147584   activation_14[0][0]        
________________________________________________________________________________
bn3b_branch2b (BatchNorma (None, 28, 28, 12 512      res3b_branch2b[0][0]       
________________________________________________________________________________
activation_15 (Activation (None, 28, 28, 12 0        bn3b_branch2b[0][0]        
________________________________________________________________________________
res3b_branch2c (Conv2D)   (None, 28, 28, 51 66048    activation_15[0][0]        
________________________________________________________________________________
bn3b_branch2c (BatchNorma (None, 28, 28, 51 2048     res3b_branch2c[0][0]       
________________________________________________________________________________
add_5 (Add)               (None, 28, 28, 51 0        bn3b_branch2c[0][0]        
                                                     activation_13[0][0]        
________________________________________________________________________________
activation_16 (Activation (None, 28, 28, 51 0        add_5[0][0]                
________________________________________________________________________________
res3c_branch2a (Conv2D)   (None, 28, 28, 12 65664    activation_16[0][0]        
________________________________________________________________________________
bn3c_branch2a (BatchNorma (None, 28, 28, 12 512      res3c_branch2a[0][0]       
________________________________________________________________________________
activation_17 (Activation (None, 28, 28, 12 0        bn3c_branch2a[0][0]        
________________________________________________________________________________
res3c_branch2b (Conv2D)   (None, 28, 28, 12 147584   activation_17[0][0]        
________________________________________________________________________________
bn3c_branch2b (BatchNorma (None, 28, 28, 12 512      res3c_branch2b[0][0]       
________________________________________________________________________________
activation_18 (Activation (None, 28, 28, 12 0        bn3c_branch2b[0][0]        
________________________________________________________________________________
res3c_branch2c (Conv2D)   (None, 28, 28, 51 66048    activation_18[0][0]        
________________________________________________________________________________
bn3c_branch2c (BatchNorma (None, 28, 28, 51 2048     res3c_branch2c[0][0]       
________________________________________________________________________________
add_6 (Add)               (None, 28, 28, 51 0        bn3c_branch2c[0][0]        
                                                     activation_16[0][0]        
________________________________________________________________________________
activation_19 (Activation (None, 28, 28, 51 0        add_6[0][0]                
________________________________________________________________________________
res3d_branch2a (Conv2D)   (None, 28, 28, 12 65664    activation_19[0][0]        
________________________________________________________________________________
bn3d_branch2a (BatchNorma (None, 28, 28, 12 512      res3d_branch2a[0][0]       
________________________________________________________________________________
activation_20 (Activation (None, 28, 28, 12 0        bn3d_branch2a[0][0]        
________________________________________________________________________________
res3d_branch2b (Conv2D)   (None, 28, 28, 12 147584   activation_20[0][0]        
________________________________________________________________________________
bn3d_branch2b (BatchNorma (None, 28, 28, 12 512      res3d_branch2b[0][0]       
________________________________________________________________________________
activation_21 (Activation (None, 28, 28, 12 0        bn3d_branch2b[0][0]        
________________________________________________________________________________
res3d_branch2c (Conv2D)   (None, 28, 28, 51 66048    activation_21[0][0]        
________________________________________________________________________________
bn3d_branch2c (BatchNorma (None, 28, 28, 51 2048     res3d_branch2c[0][0]       
________________________________________________________________________________
add_7 (Add)               (None, 28, 28, 51 0        bn3d_branch2c[0][0]        
                                                     activation_19[0][0]        
________________________________________________________________________________
activation_22 (Activation (None, 28, 28, 51 0        add_7[0][0]                
________________________________________________________________________________
res4a_branch2a (Conv2D)   (None, 14, 14, 25 131328   activation_22[0][0]        
________________________________________________________________________________
bn4a_branch2a (BatchNorma (None, 14, 14, 25 1024     res4a_branch2a[0][0]       
________________________________________________________________________________
activation_23 (Activation (None, 14, 14, 25 0        bn4a_branch2a[0][0]        
________________________________________________________________________________
res4a_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_23[0][0]        
________________________________________________________________________________
bn4a_branch2b (BatchNorma (None, 14, 14, 25 1024     res4a_branch2b[0][0]       
________________________________________________________________________________
activation_24 (Activation (None, 14, 14, 25 0        bn4a_branch2b[0][0]        
________________________________________________________________________________
res4a_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_24[0][0]        
________________________________________________________________________________
res4a_branch1 (Conv2D)    (None, 14, 14, 10 525312   activation_22[0][0]        
________________________________________________________________________________
bn4a_branch2c (BatchNorma (None, 14, 14, 10 4096     res4a_branch2c[0][0]       
________________________________________________________________________________
bn4a_branch1 (BatchNormal (None, 14, 14, 10 4096     res4a_branch1[0][0]        
________________________________________________________________________________
add_8 (Add)               (None, 14, 14, 10 0        bn4a_branch2c[0][0]        
                                                     bn4a_branch1[0][0]         
________________________________________________________________________________
activation_25 (Activation (None, 14, 14, 10 0        add_8[0][0]                
________________________________________________________________________________
res4b_branch2a (Conv2D)   (None, 14, 14, 25 262400   activation_25[0][0]        
________________________________________________________________________________
bn4b_branch2a (BatchNorma (None, 14, 14, 25 1024     res4b_branch2a[0][0]       
________________________________________________________________________________
activation_26 (Activation (None, 14, 14, 25 0        bn4b_branch2a[0][0]        
________________________________________________________________________________
res4b_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_26[0][0]        
________________________________________________________________________________
bn4b_branch2b (BatchNorma (None, 14, 14, 25 1024     res4b_branch2b[0][0]       
________________________________________________________________________________
activation_27 (Activation (None, 14, 14, 25 0        bn4b_branch2b[0][0]        
________________________________________________________________________________
res4b_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_27[0][0]        
________________________________________________________________________________
bn4b_branch2c (BatchNorma (None, 14, 14, 10 4096     res4b_branch2c[0][0]       
________________________________________________________________________________
add_9 (Add)               (None, 14, 14, 10 0        bn4b_branch2c[0][0]        
                                                     activation_25[0][0]        
________________________________________________________________________________
activation_28 (Activation (None, 14, 14, 10 0        add_9[0][0]                
________________________________________________________________________________
res4c_branch2a (Conv2D)   (None, 14, 14, 25 262400   activation_28[0][0]        
________________________________________________________________________________
bn4c_branch2a (BatchNorma (None, 14, 14, 25 1024     res4c_branch2a[0][0]       
________________________________________________________________________________
activation_29 (Activation (None, 14, 14, 25 0        bn4c_branch2a[0][0]        
________________________________________________________________________________
res4c_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_29[0][0]        
________________________________________________________________________________
bn4c_branch2b (BatchNorma (None, 14, 14, 25 1024     res4c_branch2b[0][0]       
________________________________________________________________________________
activation_30 (Activation (None, 14, 14, 25 0        bn4c_branch2b[0][0]        
________________________________________________________________________________
res4c_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_30[0][0]        
________________________________________________________________________________
bn4c_branch2c (BatchNorma (None, 14, 14, 10 4096     res4c_branch2c[0][0]       
________________________________________________________________________________
add_10 (Add)              (None, 14, 14, 10 0        bn4c_branch2c[0][0]        
                                                     activation_28[0][0]        
________________________________________________________________________________
activation_31 (Activation (None, 14, 14, 10 0        add_10[0][0]               
________________________________________________________________________________
res4d_branch2a (Conv2D)   (None, 14, 14, 25 262400   activation_31[0][0]        
________________________________________________________________________________
bn4d_branch2a (BatchNorma (None, 14, 14, 25 1024     res4d_branch2a[0][0]       
________________________________________________________________________________
activation_32 (Activation (None, 14, 14, 25 0        bn4d_branch2a[0][0]        
________________________________________________________________________________
res4d_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_32[0][0]        
________________________________________________________________________________
bn4d_branch2b (BatchNorma (None, 14, 14, 25 1024     res4d_branch2b[0][0]       
________________________________________________________________________________
activation_33 (Activation (None, 14, 14, 25 0        bn4d_branch2b[0][0]        
________________________________________________________________________________
res4d_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_33[0][0]        
________________________________________________________________________________
bn4d_branch2c (BatchNorma (None, 14, 14, 10 4096     res4d_branch2c[0][0]       
________________________________________________________________________________
add_11 (Add)              (None, 14, 14, 10 0        bn4d_branch2c[0][0]        
                                                     activation_31[0][0]        
________________________________________________________________________________
activation_34 (Activation (None, 14, 14, 10 0        add_11[0][0]               
________________________________________________________________________________
res4e_branch2a (Conv2D)   (None, 14, 14, 25 262400   activation_34[0][0]        
________________________________________________________________________________
bn4e_branch2a (BatchNorma (None, 14, 14, 25 1024     res4e_branch2a[0][0]       
________________________________________________________________________________
activation_35 (Activation (None, 14, 14, 25 0        bn4e_branch2a[0][0]        
________________________________________________________________________________
res4e_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_35[0][0]        
________________________________________________________________________________
bn4e_branch2b (BatchNorma (None, 14, 14, 25 1024     res4e_branch2b[0][0]       
________________________________________________________________________________
activation_36 (Activation (None, 14, 14, 25 0        bn4e_branch2b[0][0]        
________________________________________________________________________________
res4e_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_36[0][0]        
________________________________________________________________________________
bn4e_branch2c (BatchNorma (None, 14, 14, 10 4096     res4e_branch2c[0][0]       
________________________________________________________________________________
add_12 (Add)              (None, 14, 14, 10 0        bn4e_branch2c[0][0]        
                                                     activation_34[0][0]        
________________________________________________________________________________
activation_37 (Activation (None, 14, 14, 10 0        add_12[0][0]               
________________________________________________________________________________
res4f_branch2a (Conv2D)   (None, 14, 14, 25 262400   activation_37[0][0]        
________________________________________________________________________________
bn4f_branch2a (BatchNorma (None, 14, 14, 25 1024     res4f_branch2a[0][0]       
________________________________________________________________________________
activation_38 (Activation (None, 14, 14, 25 0        bn4f_branch2a[0][0]        
________________________________________________________________________________
res4f_branch2b (Conv2D)   (None, 14, 14, 25 590080   activation_38[0][0]        
________________________________________________________________________________
bn4f_branch2b (BatchNorma (None, 14, 14, 25 1024     res4f_branch2b[0][0]       
________________________________________________________________________________
activation_39 (Activation (None, 14, 14, 25 0        bn4f_branch2b[0][0]        
________________________________________________________________________________
res4f_branch2c (Conv2D)   (None, 14, 14, 10 263168   activation_39[0][0]        
________________________________________________________________________________
bn4f_branch2c (BatchNorma (None, 14, 14, 10 4096     res4f_branch2c[0][0]       
________________________________________________________________________________
add_13 (Add)              (None, 14, 14, 10 0        bn4f_branch2c[0][0]        
                                                     activation_37[0][0]        
________________________________________________________________________________
activation_40 (Activation (None, 14, 14, 10 0        add_13[0][0]               
________________________________________________________________________________
res5a_branch2a (Conv2D)   (None, 7, 7, 512) 524800   activation_40[0][0]        
________________________________________________________________________________
bn5a_branch2a (BatchNorma (None, 7, 7, 512) 2048     res5a_branch2a[0][0]       
________________________________________________________________________________
activation_41 (Activation (None, 7, 7, 512) 0        bn5a_branch2a[0][0]        
________________________________________________________________________________
res5a_branch2b (Conv2D)   (None, 7, 7, 512) 2359808  activation_41[0][0]        
________________________________________________________________________________
bn5a_branch2b (BatchNorma (None, 7, 7, 512) 2048     res5a_branch2b[0][0]       
________________________________________________________________________________
activation_42 (Activation (None, 7, 7, 512) 0        bn5a_branch2b[0][0]        
________________________________________________________________________________
res5a_branch2c (Conv2D)   (None, 7, 7, 2048 1050624  activation_42[0][0]        
________________________________________________________________________________
res5a_branch1 (Conv2D)    (None, 7, 7, 2048 2099200  activation_40[0][0]        
________________________________________________________________________________
bn5a_branch2c (BatchNorma (None, 7, 7, 2048 8192     res5a_branch2c[0][0]       
________________________________________________________________________________
bn5a_branch1 (BatchNormal (None, 7, 7, 2048 8192     res5a_branch1[0][0]        
________________________________________________________________________________
add_14 (Add)              (None, 7, 7, 2048 0        bn5a_branch2c[0][0]        
                                                     bn5a_branch1[0][0]         
________________________________________________________________________________
activation_43 (Activation (None, 7, 7, 2048 0        add_14[0][0]               
________________________________________________________________________________
res5b_branch2a (Conv2D)   (None, 7, 7, 512) 1049088  activation_43[0][0]        
________________________________________________________________________________
bn5b_branch2a (BatchNorma (None, 7, 7, 512) 2048     res5b_branch2a[0][0]       
________________________________________________________________________________
activation_44 (Activation (None, 7, 7, 512) 0        bn5b_branch2a[0][0]        
________________________________________________________________________________
res5b_branch2b (Conv2D)   (None, 7, 7, 512) 2359808  activation_44[0][0]        
________________________________________________________________________________
bn5b_branch2b (BatchNorma (None, 7, 7, 512) 2048     res5b_branch2b[0][0]       
________________________________________________________________________________
activation_45 (Activation (None, 7, 7, 512) 0        bn5b_branch2b[0][0]        
________________________________________________________________________________
res5b_branch2c (Conv2D)   (None, 7, 7, 2048 1050624  activation_45[0][0]        
________________________________________________________________________________
bn5b_branch2c (BatchNorma (None, 7, 7, 2048 8192     res5b_branch2c[0][0]       
________________________________________________________________________________
add_15 (Add)              (None, 7, 7, 2048 0        bn5b_branch2c[0][0]        
                                                     activation_43[0][0]        
________________________________________________________________________________
activation_46 (Activation (None, 7, 7, 2048 0        add_15[0][0]               
________________________________________________________________________________
res5c_branch2a (Conv2D)   (None, 7, 7, 512) 1049088  activation_46[0][0]        
________________________________________________________________________________
bn5c_branch2a (BatchNorma (None, 7, 7, 512) 2048     res5c_branch2a[0][0]       
________________________________________________________________________________
activation_47 (Activation (None, 7, 7, 512) 0        bn5c_branch2a[0][0]        
________________________________________________________________________________
res5c_branch2b (Conv2D)   (None, 7, 7, 512) 2359808  activation_47[0][0]        
________________________________________________________________________________
bn5c_branch2b (BatchNorma (None, 7, 7, 512) 2048     res5c_branch2b[0][0]       
________________________________________________________________________________
activation_48 (Activation (None, 7, 7, 512) 0        bn5c_branch2b[0][0]        
________________________________________________________________________________
res5c_branch2c (Conv2D)   (None, 7, 7, 2048 1050624  activation_48[0][0]        
________________________________________________________________________________
bn5c_branch2c (BatchNorma (None, 7, 7, 2048 8192     res5c_branch2c[0][0]       
________________________________________________________________________________
add_16 (Add)              (None, 7, 7, 2048 0        bn5c_branch2c[0][0]        
                                                     activation_46[0][0]        
________________________________________________________________________________
activation_49 (Activation (None, 7, 7, 2048 0        add_16[0][0]               
________________________________________________________________________________
avg_pool (AveragePooling2 (None, 1, 1, 2048 0        activation_49[0][0]        
________________________________________________________________________________
flatten_1 (Flatten)       (None, 2048)      0        avg_pool[0][0]             
________________________________________________________________________________
fc1000 (Dense)            (None, 1000)      2049000  flatten_1[0][0]            
================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
________________________________________________________________________________
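
Looking ahead to step 1: configuring the subgraph means telling tfcompile which tensors to feed and fetch, via a tf2xla `Config` in text-proto form. A minimal sketch might look like the following; the node names and the batch size of 1 are hypothetical and would have to be read off the actual graph:

```
# graph_config.pbtxt (sketch; node names are hypothetical)
feed {
  id { node_name: "input_1" }
  shape {
    dim { size: 1 }
    dim { size: 224 }
    dim { size: 224 }
    dim { size: 3 }
  }
}
fetch {
  id { node_name: "fc1000/Softmax" }
}
```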

Step 0.5: Download tfcompile

XLA is still maturing, so as of now we have to check out the development release. System prerequisites are git, the build tool Bazel, and the Protocol Buffers compiler. I'm also assuming we're running tf-nightly, which can be installed via pip.

In [5]:
%rm -rf /tmp/tensorflow
In [6]:
%cd /tmp
!git clone --depth=1 --single-branch https://github.com/tensorflow/tensorflow
%cd tensorflow
!yes "" | ./configure
!protoc tensorflow/compiler/tf2xla/tf2xla.proto --python_out=.
!cp tensorflow/compiler/tf2xla/tf2xla_pb2.py .
/tmp
Cloning into 'tensorflow'...
remote: Counting objects: 10580, done.
remote: Compressing objects: 100% (8825/8825), done.
remote: Total 10580 (delta 3329), reused 3594 (delta 1486), pack-reused 0
Receiving objects: 100% (10580/10580), 21.65 MiB | 4.71 MiB/s, done.
Resolving deltas: 100% (3329/3329), done.
/tmp/tensorflow
WARNING: Running Bazel server needs to be killed, because the startup options are different.
You have bazel 0.8.1 installed.
Please specify the location of python. [Default is /home/carl/anaconda3/bin/python]: 

Found possible Python library paths:
  /home/carl/anaconda3/lib/python3.6/site-packages
Please input the desired Python library path to use.  Default is [/home/carl/anaconda3/lib/python3.6/site-packages]
Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: jemalloc as malloc support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: Hadoop File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Amazon S3 File System support? [Y/n]: Amazon S3 File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with XLA JIT support? [y/N]: No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with GDR support? [y/N]: No GDR support will be enabled for TensorFlow.

Do you wish to build TensorFlow with VERBS support? [y/N]: No VERBS support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: No CUDA support will be enabled for TensorFlow.

Do you wish to build TensorFlow with MPI support? [y/N]: No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]: 

Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
Configuration finished
yes: standard output: Broken pipe

Step 1: Configure the subgraph to compile.

List feeds and fetches

tfcompile needs static input shapes, so we have to pick a fixed batch size for our image classifier.

In [7]:
import tf2xla_pb2

config = tf2xla_pb2.Config()

batch_size = 1

for x in model.inputs:
    x.set_shape([batch_size] + list(x.shape)[1:])
    feed = config.feed.add()
    feed.id.node_name = x.op.name
    feed.shape.MergeFrom(x.shape.as_proto())

for x in model.outputs:
    fetch = config.fetch.add()
    fetch.id.node_name = x.op.name

with open('graph.config.pbtxt', 'w') as f:
    f.write(str(config))
In [8]:
cat graph.config.pbtxt
feed {
  id {
    node_name: "input_1"
  }
  shape {
    dim {
      size: 1
    }
    dim {
      size: 224
    }
    dim {
      size: 224
    }
    dim {
      size: 3
    }
  }
}
fetch {
  id {
    node_name: "fc1000/Softmax"
  }
}
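
The Config above is plain protobuf text format. For illustration, here is a hypothetical helper (`make_config` is not part of TensorFlow) that emits the same feed/fetch structure without needing the generated tf2xla_pb2 module; its layout mirrors the output shown:

```python
def make_config(feed_name, shape, fetch_name):
    """Emit a tf2xla Config in protobuf text format (illustrative only)."""
    dims = '\n'.join('    dim {\n      size: %d\n    }' % d for d in shape)
    return ('feed {\n  id {\n    node_name: "%s"\n  }\n  shape {\n%s\n  }\n}\n'
            'fetch {\n  id {\n    node_name: "%s"\n  }\n}\n'
            % (feed_name, dims, fetch_name))

print(make_config('input_1', (1, 224, 224, 3), 'fc1000/Softmax'))
```

In practice you should build the message with tf2xla_pb2 as above, which validates field names for you; this sketch only shows what the text format looks like.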

Freeze graph

The graph contains mutable nodes (variables) that have to be turned into constants. It's possible to let tfcompile handle this for you (via freeze_graph.py) by providing a weights checkpoint along with the graph definition, but since we already have everything loaded we'll convert the variables to constants right away.
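
Conceptually, freezing walks the graph and replaces each variable with a constant node holding its current value. A toy sketch of the idea, using a hypothetical dict-based graph rather than the actual TensorFlow implementation:

```python
def freeze(graph, values):
    """Replace Variable nodes with Const nodes holding their current values."""
    frozen = {}
    for name, node in graph.items():
        if node['op'] == 'Variable':
            frozen[name] = {'op': 'Const', 'value': values[name]}
        else:
            frozen[name] = dict(node)
    return frozen

# A miniature "graph": one weight variable feeding a matmul.
graph = {'w': {'op': 'Variable'},
         'x': {'op': 'Placeholder'},
         'y': {'op': 'MatMul', 'inputs': ['x', 'w']}}
frozen = freeze(graph, {'w': [[1.0, 2.0]]})
```

The real convert_variables_to_constants call below does exactly this on the GraphDef, which is why it reports "Converted 320 variables to const ops".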

In [10]:
session = tf.keras.backend.get_session()
output_node_names = [node.op.name for node in model.outputs]
graphdef = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names)
tf.train.write_graph(graphdef, '.', 'graph.pb', as_text=False)
INFO:tensorflow:Froze 320 variables.
Converted 320 variables to const ops.
Out[10]:
'./graph.pb'

Step 2: Use the tf_library build macro to compile the subgraph.

In [11]:
%%writefile BUILD

load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')

tf_library(
    name = 'graph',
    config = 'graph.config.pbtxt',
    cpp_class = 'Graph',
    graph = 'graph.pb',
)
Overwriting BUILD
In [12]:
!bazel build --show_progress_rate_limit=600 @org_tensorflow//:graph
.......
Loading: 
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
Analyzing: target @org_tensorflow//:graph (68 packages loaded)
INFO: Analysed target @org_tensorflow//:graph (74 packages loaded).
INFO: Found 1 target...
[0 / 6] BazelWorkspaceStatusAction stable-status.txt
INFO: From Executing genrule @org_tensorflow//tensorflow/core:version_info_gen [for host]:
fatal: No names found, cannot describe anything.
INFO: From Executing genrule @org_tensorflow//:gen_graph:
2018-01-11 15:27:20.408071: I external/org_tensorflow/tensorflow/core/platform/s3/aws_logging.cc:53] Initializing Curl library
2018-01-11 15:27:20.514752: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
Target @org_tensorflow//:graph up-to-date:
  bazel-bin/external/org_tensorflow/libgraph.a
  bazel-bin/external/org_tensorflow/libgraph.pic.a
  bazel-bin/external/org_tensorflow/libgraph.so
INFO: Elapsed time: 57.837s, Critical Path: 50.33s
INFO: Build completed successfully, 3 total actions
In [13]:
cat bazel-genfiles/graph.h
// Generated by tfcompile, the TensorFlow graph compiler.  DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph.  An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)


#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void ____graph(
    void* result, const xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);


// Graph represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   Graph computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while it
//   is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   (arg0: f32[1,224,224,3]) -> (f32[1,1000])
//
// Memory stats:
//   arg bytes total:    602112
//   arg bytes aligned:  602112
//   temp bytes total:   17815208
//   temp bytes aligned: 17815232
class Graph : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = 1;

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const intptr_t* ArgSizes() {
    static constexpr intptr_t kArgSizes[kNumArgs] = {602112};
    return kArgSizes;
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
        new XlaCompiledCpuFunction::StaticData;
      data->raw_function = ____graph;
      data->arg_sizes = ArgSizes();
      data->num_args = kNumArgs;
      data->temp_sizes = TempSizes();
      data->num_temps = kNumTemps;
      data->result_index = kResultIndex;
      data->arg_names = StaticArgNames();
      data->result_names = StaticResultNames();
      data->program_shape = StaticProgramShape();
      return data;
    }();
    return *kStaticData;
  }

  Graph(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  Graph(const Graph&) = delete;
  Graph& operator=(const Graph&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  // void set_argN_data(void* data)
  //   Sets the buffer of type T for positional argument N. May be called in
  //   any AllocMode. Must be called before Run to have an affect. Must be
  //   called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //   argument, to set the argument buffers.
  //
  // T* argN_data()
  //   Returns the buffer of type T for positional argument N.
  //
  // T& argN(...dim indices...)
  //   Returns a reference to the value of type T for positional argument N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.

  void set_arg0_data(void* data) {
    set_arg_data(0, data);
  }
  float* arg0_data() {
    return static_cast<float*>(arg_data(0));
  }
  float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) {
    return (*static_cast<float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }
  const float* arg0_data() const {
    return static_cast<const float*>(arg_data(0));
  }
  const float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) const {
    return (*static_cast<const float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  // T* resultN_data()
  //   Returns the buffer of type T for positional result N.
  //
  // T& resultN(...dim indices...)
  //   Returns a reference to the value of type T for positional result N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.

  float* result0_data() {
    return static_cast<float*>(result_data(0));
  }
  float& result0(size_t dim0, size_t dim1) {
    return (*static_cast<float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }
  const float* result0_data() const {
    return static_cast<const float*>(result_data(0));
  }
  const float& result0(size_t dim0, size_t dim1) const {
    return (*static_cast<const float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }

 private:
  // Number of result and temporary buffers for the compiled computation.
  static constexpr size_t kNumTemps = 10;
  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = 2;

  // Byte size of each result / temporary buffer. There are kNumTemps entries.
  static const intptr_t* TempSizes() {
    static constexpr intptr_t kTempSizes[kNumTemps] = {-1, 4000, 8, -1, -1, -1, -1, -1, -1, 17811200};
    return kTempSizes;
  }

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {
    return nullptr;
  }

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {
    return nullptr;
  }

  // Shape of the args and results.
  static const xla::ProgramShape* StaticProgramShape() {
    return nullptr;
  }
};


#endif  // TFCOMPILE_GENERATED_____graph_H_

// clang-format on
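
The memory stats in the generated header can be checked by hand: the single argument is one float32 image of shape 1×224×224×3, and the 4000-byte entry among the temp buffers holds the 1×1000 float32 result:

```python
FLOAT32_BYTES = 4

# "arg bytes total: 602112" in the header: one (1, 224, 224, 3) float32 input
arg_bytes = 1 * 224 * 224 * 3 * FLOAT32_BYTES

# the 4000-byte temp buffer: one (1, 1000) float32 result
result_bytes = 1 * 1000 * FLOAT32_BYTES

print(arg_bytes, result_bytes)  # 602112 4000
```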

Step 3: Write code to invoke the subgraph.

In [14]:
%%writefile graph.cc

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include "graph.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

extern "C" int run(float *input, float *output, int input_size, int output_size) {
  Eigen::ThreadPool tp(std::thread::hardware_concurrency());
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
  Graph graph;
  graph.set_thread_pool(&device);

  std::copy(input, input + input_size, graph.arg0_data());
  auto ok = graph.Run();
  if (not ok) return -1;
  std::copy(graph.result0_data(), graph.result0_data() + output_size, output);
  return 0;
}
Writing graph.cc

Step 4: Create the final binary.

Since Bazel is already required for building the tfcompile tool, we'll add a cc_binary rule instead of invoking gcc directly. In fact, we could have written one big BUILD file right after cloning the TensorFlow repo.

In [15]:
%%writefile -a BUILD

cc_binary(
    name = "libmodel.so",
    srcs = ["graph.cc"],
    deps = [":graph", "//third_party/eigen3"],
    linkopts = ["-lpthread"],
    linkshared = 1,
    copts = ["-fPIC"],
)
Appending to BUILD
In [16]:
!bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so
Loading: 
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
Analyzing: target @org_tensorflow//:libmodel.so (2 packages loaded)
INFO: Analysed target @org_tensorflow//:libmodel.so (2 packages loaded).
INFO: Found 1 target...
[0 / 5] BazelWorkspaceStatusAction stable-status.txt
Target @org_tensorflow//:libmodel.so up-to-date:
  bazel-bin/external/org_tensorflow/libmodel.so
INFO: Elapsed time: 1.852s, Critical Path: 0.56s
INFO: Build completed successfully, 1 total action
In [19]:
import numpy as np

libmodel = np.ctypeslib.load_library('libmodel', 'bazel-bin/external/org_tensorflow')
libmodel.run.argtypes = [
    np.ctypeslib.ndpointer(np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a')),
    np.ctypeslib.ndpointer(np.float32, ndim=2, shape=(1, 1000), flags=('c', 'a', 'w')),
    np.ctypeslib.ctypes.c_int,
    np.ctypeslib.ctypes.c_int]


def predict(x):
    x = np.require(x, np.float32, ('c', 'a'))
    y = np.require(np.zeros((1, 1000)), np.float32, ('c', 'a', 'w'))
    # run returns -1 if the compiled computation failed
    if libmodel.run(x, y, x.size, y.size) != 0:
        raise RuntimeError('graph computation failed')
    return y
In [20]:
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions

image_path = input()

x = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))
x = x[None, ...]
x = preprocess_input(x)
y = predict(x)
decode_predictions(y)[0]
Out[20]:
[('n02110806', 'basenji', 0.60816735),
 ('n02441942', 'weasel', 0.10849755),
 ('n02091244', 'Ibizan_hound', 0.081580825),
 ('n02124075', 'Egyptian_cat', 0.044705715),
 ('n02123597', 'Siamese_cat', 0.025189402)]
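
decode_predictions just maps the highest-scoring indices to class labels. The same top-k logic can be sketched with numpy alone (the three-entry label list here is a hypothetical stand-in for the 1000 ImageNet classes):

```python
import numpy as np

def top_k(probs, labels, k=5):
    """Return the k highest-probability (label, score) pairs, best first."""
    idx = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in idx]

labels = ['basenji', 'weasel', 'cat']    # hypothetical label list
probs = np.array([0.61, 0.11, 0.28])
print(top_k(probs, labels, k=2))         # [('basenji', 0.61), ('cat', 0.28)]
```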
In [36]:
%timeit model.predict(x)
%timeit predict(x)
np.testing.assert_allclose(model.predict(x), predict(x), atol=1e-5)
150 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
191 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [37]:
%%timeit
model = tf.keras.applications.ResNet50()
model.predict(x)
2.96 s ± 456 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
