Skip to main content

Zero-copy data sharing between JAX and TensorFlow via DLPack

In [1]:
import numpy as np
import tensorflow as tf
import jax.dlpack

# Move a random TensorFlow tensor into JAX through a DLPack capsule.
tf_arr = tf.random.uniform(shape=(10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)

# The JAX array must hold exactly the same values as the source tensor.
np.testing.assert_array_equal(tf_arr, jax_arr)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [2]:
import jax.numpy as jnp
In [3]:
def tf_to_jax(arr):
  """Convert a TensorFlow tensor to a JAX array via DLPack.

  The tensor is exported as a DLPack capsule and imported by JAX. The intent
  (checked later in this notebook with `unsafe_buffer_pointer()`) is that the
  result shares the tensor's buffer instead of copying it.

  Args:
    arr: the TensorFlow tensor to convert.

  Returns:
    The JAX array produced by `jax.dlpack.from_dlpack` for `arr`.
  """
  # Convert the argument itself -- never a notebook-global tensor -- so the
  # helper is correct for any input and on a fresh kernel.
  return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))

def jax_to_tf(arr):
  """Convert a JAX array to a TensorFlow tensor via DLPack.

  The array is exported from JAX as a DLPack capsule, which TensorFlow then
  imports, so the data is handed over rather than rebuilt element by element.

  Args:
    arr: the JAX array to export.

  Returns:
    The TensorFlow tensor produced by `tf.experimental.dlpack.from_dlpack`.
  """
  capsule = jax.dlpack.to_dlpack(arr)
  return tf.experimental.dlpack.from_dlpack(capsule)

# Round trip JAX -> TensorFlow -> JAX using the two DLPack helpers.
jax_arr = jnp.arange(20.0)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)
In [4]:
# Reduce the element-wise comparison to one boolean: did every value survive?
(jax_arr == jax_arr2).all()
Out[4]:
DeviceArray(True, dtype=bool)
In [5]:
# Zero-copy check: both arrays should report the same underlying buffer address.
jax_arr2.unsafe_buffer_pointer() == jax_arr.unsafe_buffer_pointer()
Out[5]:
True