Zero-copy data sharing between JAX and TensorFlow via DLPack
In [1]:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform(shape=(10,))
# Export the TensorFlow tensor as a DLPack capsule...
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
# ...and import that capsule on the JAX side.
jax_arr = jax.dlpack.from_dlpack(dl_arr)
# The JAX view must hold exactly the values TensorFlow generated.
np.testing.assert_array_equal(tf_arr, jax_arr)
In [2]:
import jax.numpy as jnp
In [3]:
def tf_to_jax(arr):
    """Convert a TensorFlow tensor to a JAX array through DLPack.

    The tensor is exported as a DLPack capsule with
    ``tf.experimental.dlpack.to_dlpack`` and imported with
    ``jax.dlpack.from_dlpack``; the intent is to share the underlying
    buffer rather than copy it.

    Args:
        arr: the TensorFlow tensor to convert.

    Returns:
        The JAX array produced by ``jax.dlpack.from_dlpack``.
    """
    # Convert the argument itself. The previous version read the global
    # `tf_arr` instead, which only worked by coincidence.
    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))
def jax_to_tf(arr):
    """Convert a JAX array to a TensorFlow tensor through DLPack.

    Args:
        arr: the JAX array to convert.

    Returns:
        The tensor built by ``tf.experimental.dlpack.from_dlpack`` from the
        DLPack capsule that JAX exports for ``arr``.
    """
    capsule = jax.dlpack.to_dlpack(arr)
    return tf.experimental.dlpack.from_dlpack(capsule)
# Start from a float JAX array of the values 0..19.
jax_arr = jnp.arange(20.0)
# JAX -> TensorFlow through DLPack.
tf_arr = jax_to_tf(jax_arr)
# ...and back again: TensorFlow -> JAX.
jax_arr2 = tf_to_jax(tf_arr)
In [4]:
# Element-wise comparison of the original and round-tripped arrays, reduced with .all()
(jax_arr == jax_arr2).all()
Out[4]:
In [5]:
# True when both JAX arrays report the same device-buffer address.
jax_arr2.unsafe_buffer_pointer() == jax_arr.unsafe_buffer_pointer()
Out[5]:
Comments
Comments powered by Disqus