tfcompile
Deploying a TensorFlow graph via XLA AOT compilation
Many machine learning models are deployed as cloud services, where a full-blown runtime can be accommodated, but managing servers and requiring internet connectivity for your app is a hassle. Instead, you can use tfcompile (an XLA CLI tool) to compile a TensorFlow graph into executable machine code, and then deploy that as a microservice or native application.
XLA
XLA is a compiler for TensorFlow graphs.
- TensorFlow's graph abstraction incurs runtime overhead.
- XLA combats this overhead, so we can afford to write high-level code without relying on the existence of hand-written custom op kernels.
- The compiler can also be used for just-in-time (JIT) graph optimization during model training (see the sketch right after this list), but we'll focus on ahead-of-time (AOT) compilation for model deployment.
- The implementation is still maturing: XLA was released in March 2017 and sees several commits per day.
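For contrast, here's roughly what switching on JIT compilation looks like with the TF 1.x session API — a minimal sketch, not needed for the AOT workflow below:

import tensorflow as tf

# Let XLA JIT-compile eligible subgraphs for the whole session.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
session = tf.Session(config=config)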
Steps for ahead-of-time compilation of a graph with XLA
We'll use the command-line tool tfcompile via Bazel.
- Configure the subgraph to compile.
- Use the tf_library build macro to compile the subgraph.
- Write code to invoke the subgraph.
- Create the final binary.
Step 0: Model
Before we can compile a graph, we need to build one. Let's keep it simple and load a pretrained image classifier.
# This cell can be safely removed and doesn't need to be run.
%env CUDA_VISIBLE_DEVICES=''
env: CUDA_VISIBLE_DEVICES=''
import tensorflow as tf
tf.keras.backend.set_learning_phase(False)
model = tf.keras.applications.ResNet50()
model.summary(80)
________________________________________________________________________________
Layer (type)                   Output Shape         Param #   Connected to
================================================================================
input_1 (InputLayer)           (None, 224, 224, 3)  0
________________________________________________________________________________
conv1 (Conv2D)                 (None, 112, 112, 64) 9472      input_1[0][0]
________________________________________________________________________________
... (ResNet-50 residual blocks elided for brevity) ...
________________________________________________________________________________
avg_pool (AveragePooling2D)    (None, 1, 1, 2048)   0         activation_49[0][0]
________________________________________________________________________________
flatten_1 (Flatten)            (None, 2048)         0         avg_pool[0][0]
________________________________________________________________________________
fc1000 (Dense)                 (None, 1000)         2049000   flatten_1[0][0]
================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
________________________________________________________________________________
Step 0.5: Download tfcompile
XLA is still maturing, so as of now we have to check out the development version of TensorFlow. System prerequisites are git, the build tool Bazel, and the Protocol Buffers compiler. I'm also assuming we're running tf-nightly, which can be installed via pip.
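For example:

!pip install tf-nightly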
%rm -rf /tmp/tensorflow
%cd /tmp
!git clone --depth=1 --single-branch https://github.com/tensorflow/tensorflow
%cd tensorflow
!yes "" | ./configure
!protoc tensorflow/compiler/tf2xla/tf2xla.proto --python_out=.
!cp tensorflow/compiler/tf2xla/tf2xla_pb2.py .
/tmp
Cloning into 'tensorflow'...
remote: Counting objects: 10580, done.
remote: Compressing objects: 100% (8825/8825), done.
remote: Total 10580 (delta 3329), reused 3594 (delta 1486), pack-reused 0
Receiving objects: 100% (10580/10580), 21.65 MiB | 4.71 MiB/s, done.
Resolving deltas: 100% (3329/3329), done.
/tmp/tensorflow
WARNING: Running Bazel server needs to be killed, because the startup options are different.
You have bazel 0.8.1 installed.
Please specify the location of python. [Default is /home/carl/anaconda3/bin/python]:
Found possible Python library paths:
  /home/carl/anaconda3/lib/python3.6/site-packages
Please input the desired Python library path to use. Default is [/home/carl/anaconda3/lib/python3.6/site-packages]
Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]:
jemalloc as malloc support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]:
Google Cloud Platform support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Hadoop File System support? [Y/n]:
Hadoop File System support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Amazon S3 File System support? [Y/n]:
Amazon S3 File System support will be enabled for TensorFlow.
Do you wish to build TensorFlow with XLA JIT support? [y/N]:
No XLA JIT support will be enabled for TensorFlow.
Do you wish to build TensorFlow with GDR support? [y/N]:
No GDR support will be enabled for TensorFlow.
Do you wish to build TensorFlow with VERBS support? [y/N]:
No VERBS support will be enabled for TensorFlow.
Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]:
No OpenCL SYCL support will be enabled for TensorFlow.
Do you wish to build TensorFlow with CUDA support? [y/N]:
No CUDA support will be enabled for TensorFlow.
Do you wish to build TensorFlow with MPI support? [y/N]:
No MPI support will be enabled for TensorFlow.
Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]:
Not configuring the WORKSPACE for Android builds.
Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
  --config=mkl          # Build with MKL support.
  --config=monolithic   # Config for mostly static monolithic build.
Configuration finished
yes: standard output: Broken pipe
Step 1: Configure the subgraph to compile.
List feeds and fetches
tfcompile needs static input shapes, so we have to pick a batch size for our image classifier.
import tf2xla_pb2
config = tf2xla_pb2.Config()
batch_size = 1
for x in model.inputs:
    x.set_shape([batch_size] + list(x.shape)[1:])
    feed = config.feed.add()
    feed.id.node_name = x.op.name
    feed.shape.MergeFrom(x.shape.as_proto())

for x in model.outputs:
    fetch = config.fetch.add()
    fetch.id.node_name = x.op.name

with open('graph.config.pbtxt', 'w') as f:
    f.write(str(config))
cat graph.config.pbtxt
feed {
  id {
    node_name: "input_1"
  }
  shape {
    dim { size: 1 }
    dim { size: 224 }
    dim { size: 224 }
    dim { size: 3 }
  }
}
fetch {
  id {
    node_name: "fc1000/Softmax"
  }
}
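These node names come straight from the Keras model; if in doubt, they can be printed directly:

# Sanity check: the feed/fetch node names in the config match the Keras tensors.
print(model.inputs[0].op.name)   # input_1
print(model.outputs[0].op.name)  # fc1000/Softmax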
Freeze graph
The graph contains mutable nodes (the variables) that have to become constants. It's possible to let tfcompile handle this for you (via freeze_graph.py) by providing a weights checkpoint along with the graph definition, but since we already have everything loaded, we'll turn them into constants right away.
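For reference, the checkpoint route would look roughly like the sketch below; my understanding from tfcompile.bzl is that tf_library's freeze_checkpoint attribute is what consumes the checkpoint, but we won't use this path here.

# Hypothetical alternative, not run in this post:
saver = tf.train.Saver()
saver.save(tf.keras.backend.get_session(), '/tmp/weights.ckpt')
# ...then point tf_library's freeze_checkpoint attribute at the checkpoint.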
session = tf.keras.backend.get_session()
output_node_names = [node.op.name for node in model.outputs]
graphdef = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names)
tf.train.write_graph(graphdef, '.', 'graph.pb', as_text=False)
INFO:tensorflow:Froze 320 variables. Converted 320 variables to const ops.
'./graph.pb'
Step 2: Use the tf_library build macro to compile the subgraph.
%%writefile BUILD
load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')
tf_library(
    name = 'graph',
    config = 'graph.config.pbtxt',
    cpp_class = 'Graph',
    graph = 'graph.pb',
)
Overwriting BUILD
!bazel build --show_progress_rate_limit=600 @org_tensorflow//:graph
.......
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
INFO: Analysed target @org_tensorflow//:graph (74 packages loaded).
INFO: Found 1 target...
[0 / 6] BazelWorkspaceStatusAction stable-status.txt
INFO: From Executing genrule @org_tensorflow//tensorflow/core:version_info_gen [for host]:
fatal: No names found, cannot describe anything.
INFO: From Executing genrule @org_tensorflow//:gen_graph:
2018-01-11 15:27:20.408071: I external/org_tensorflow/tensorflow/core/platform/s3/aws_logging.cc:53] Initializing Curl library
2018-01-11 15:27:20.514752: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[3,332 / 3,336] Executing genrule @org_tensorflow//:gen_graph; 47s local
Target @org_tensorflow//:graph up-to-date:
  bazel-bin/external/org_tensorflow/libgraph.a
  bazel-bin/external/org_tensorflow/libgraph.pic.a
  bazel-bin/external/org_tensorflow/libgraph.so
INFO: Elapsed time: 57.837s, Critical Path: 50.33s
INFO: Build completed successfully, 3 total actions
cat bazel-genfiles/graph.h
// Generated by tfcompile, the TensorFlow graph compiler.  DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph.  An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_____graph_H_  // NOLINT(build/header_guard)

#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void ____graph(
    void* result, const xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);

// Graph represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   Graph computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while it
//   is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   (arg0: f32[1,224,224,3]) -> (f32[1,1000])
//
// Memory stats:
//   arg bytes total:    602112
//   arg bytes aligned:  602112
//   temp bytes total:   17815208
//   temp bytes aligned: 17815232
class Graph : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = 1;

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const intptr_t* ArgSizes() {
    static constexpr intptr_t kArgSizes[kNumArgs] = {602112};
    return kArgSizes;
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
          new XlaCompiledCpuFunction::StaticData;
      data->raw_function = ____graph;
      data->arg_sizes = ArgSizes();
      data->num_args = kNumArgs;
      data->temp_sizes = TempSizes();
      data->num_temps = kNumTemps;
      data->result_index = kResultIndex;
      data->arg_names = StaticArgNames();
      data->result_names = StaticResultNames();
      data->program_shape = StaticProgramShape();
      return data;
    }();
    return *kStaticData;
  }

  Graph(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  Graph(const Graph&) = delete;
  Graph& operator=(const Graph&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  // void set_argN_data(void* data)
  //   Sets the buffer of type T for positional argument N. May be called in
  //   any AllocMode. Must be called before Run to have an affect. Must be
  //   called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //   argument, to set the argument buffers.
  //
  // T* argN_data()
  //   Returns the buffer of type T for positional argument N.
  //
  // T& argN(...dim indices...)
  //   Returns a reference to the value of type T for positional argument N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  void set_arg0_data(void* data) {
    set_arg_data(0, data);
  }
  float* arg0_data() {
    return static_cast<float*>(arg_data(0));
  }
  float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) {
    return (*static_cast<float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }
  const float* arg0_data() const {
    return static_cast<const float*>(arg_data(0));
  }
  const float& arg0(size_t dim0, size_t dim1, size_t dim2, size_t dim3) const {
    return (*static_cast<const float(*)[1][224][224][3]>(
        arg_data(0)))[dim0][dim1][dim2][dim3];
  }

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  // T* resultN_data()
  //   Returns the buffer of type T for positional result N.
  //
  // T& resultN(...dim indices...)
  //   Returns a reference to the value of type T for positional result N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.
  float* result0_data() {
    return static_cast<float*>(result_data(0));
  }
  float& result0(size_t dim0, size_t dim1) {
    return (*static_cast<float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }
  const float* result0_data() const {
    return static_cast<const float*>(result_data(0));
  }
  const float& result0(size_t dim0, size_t dim1) const {
    return (*static_cast<const float(*)[1][1000]>(
        result_data(0)))[dim0][dim1];
  }

 private:
  // Number of result and temporary buffers for the compiled computation.
  static constexpr size_t kNumTemps = 10;
  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = 2;

  // Byte size of each result / temporary buffer. There are kNumTemps entries.
  static const intptr_t* TempSizes() {
    static constexpr intptr_t kTempSizes[kNumTemps] =
        {-1, 4000, 8, -1, -1, -1, -1, -1, -1, 17811200};
    return kTempSizes;
  }

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() { return nullptr; }

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() { return nullptr; }

  // Shape of the args and results.
  static const xla::ProgramShape* StaticProgramShape() { return nullptr; }
};

#endif  // TFCOMPILE_GENERATED_____graph_H_

// clang-format on
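The generated comments already show the intended usage; distilled into a minimal C++ sketch (buffer filling elided), it's just:

#include "graph.h"

int main() {
  Graph computation;
  // ...fill computation.arg0_data() with a 1x224x224x3 float image...
  if (!computation.Run()) return 1;
  // ...read the f32[1,1000] class probabilities via computation.result0_data()...
  return 0;
}

Our actual wrapper in the next step does the same, plus thread-pool setup and a C-linkage entry point we can call from Python.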
Step 3: Write code to invoke the subgraph.
%%writefile graph.cc
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <algorithm>  // std::copy
#include <thread>     // std::thread::hardware_concurrency

#include "graph.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Exposed with C linkage so we can call it via ctypes later.
extern "C" int run(float *input, float *output, int input_size, int output_size) {
  // Give the compiled function an Eigen thread pool to run on.
  Eigen::ThreadPool tp(std::thread::hardware_concurrency());
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
  Graph graph;
  graph.set_thread_pool(&device);
  std::copy(input, input + input_size, graph.arg0_data());
  auto ok = graph.Run();
  if (not ok) return -1;
  std::copy(graph.result0_data(), graph.result0_data() + output_size, output);
  return 0;
}
Writing graph.cc
Step 4: Create the final binary.
Instead of calling gcc directly, and since Bazel is already required for building the tfcompile tool anyway, we'll write a cc_binary rule. In fact, we could have put everything into one big BUILD file right after cloning the TensorFlow repo.
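For the curious, the hand-rolled equivalent would be something like the line below; the paths and flags are illustrative assumptions, and resolving them correctly is exactly what the cc_binary rule buys us.

g++ -std=c++11 -shared -fPIC -I. -Ibazel-genfiles graph.cc \
    bazel-bin/external/org_tensorflow/libgraph.a -lpthread -o libmodel.so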
%%writefile -a BUILD
cc_binary(
    name = "libmodel.so",
    srcs = ["graph.cc"],
    deps = [":graph", "//third_party/eigen3"],
    linkopts = ["-lpthread"],
    linkshared = 1,
    copts = ["-fPIC"],
)
Appending to BUILD
!bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so
Loading: 0 packages loaded
WARNING: /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/core/BUILD:1816:1: in includes attribute of cc_library rule @org_tensorflow//tensorflow/core:framework_headers_lib: '../../../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'external/org_tensorflow/tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/carl/.cache/bazel/_bazel_carl/e5cce820cc082410b4fcc604db349066/external/org_tensorflow/tensorflow/tensorflow.bzl:1143:30
INFO: Analysed target @org_tensorflow//:libmodel.so (2 packages loaded).
INFO: Found 1 target...
[0 / 5] BazelWorkspaceStatusAction stable-status.txt
Target @org_tensorflow//:libmodel.so up-to-date:
  bazel-bin/external/org_tensorflow/libmodel.so
INFO: Elapsed time: 1.852s, Critical Path: 0.56s
INFO: Build completed successfully, 1 total action
import numpy as np
libmodel = np.ctypeslib.load_library('libmodel', 'bazel-bin/external/org_tensorflow')
libmodel.run.argtypes = [
    np.ctypeslib.ndpointer(np.float32, ndim=4, shape=(1, 224, 224, 3), flags=('c', 'a')),
    np.ctypeslib.ndpointer(np.float32, ndim=2, shape=(1, 1000), flags=('c', 'a', 'w')),
    np.ctypeslib.ctypes.c_int,
    np.ctypeslib.ctypes.c_int]
def predict(x):
    # np.require gives us C-contiguous, aligned (and for y, writeable) buffers
    # matching the argtypes declared above.
    x = np.require(x, np.float32, ('c', 'a'))
    y = np.require(np.zeros((1, 1000)), np.float32, ('c', 'a', 'w'))
    libmodel.run(x, y, x.size, y.size)
    return y
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
image_path = input()
x = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))
x = x[None, ...]
x = preprocess_input(x)
y = predict(x)
decode_predictions(y)[0]
[('n02110806', 'basenji', 0.60816735), ('n02441942', 'weasel', 0.10849755), ('n02091244', 'Ibizan_hound', 0.081580825), ('n02124075', 'Egyptian_cat', 0.044705715), ('n02123597', 'Siamese_cat', 0.025189402)]
%timeit model.predict(x)
%timeit predict(x)
np.testing.assert_allclose(model.predict(x), predict(x), atol=1e-5)
150 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
191 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
model = tf.keras.applications.ResNet50()
model.predict(x)
2.96 s ± 456 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)