This is an automated email from the ASF dual-hosted git repository.
felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 9fa78d0d1e GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230)
9fa78d0d1e is described below
commit 9fa78d0d1e2cc22086e8d79afba518710a1e657e
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Wed Sep 25 21:20:47 2024 -0300
GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230)
### Rationale for this change
Explicitly document one way in which PyArrow can interoperate with
[JAX](https://github.com/jax-ml/jax).
### What changes are included in this PR?
- Phrasing tweaks to the section heading and the existing NumPy and PyTorch examples
- Two JAX examples: one using `jax.numpy.from_dlpack` and one using `jax.dlpack.from_dlpack`
### Are these changes tested?
N/A
* GitHub Issue: #44229
Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
docs/source/python/dlpack.rst | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/docs/source/python/dlpack.rst b/docs/source/python/dlpack.rst
index 024c2800e1..9f0d3b58aa 100644
--- a/docs/source/python/dlpack.rst
+++ b/docs/source/python/dlpack.rst
@@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol
(``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can
thus be consumed by libraries implementing ``from_dlpack``.
-Example
--------
+Examples
+--------
-Convert a PyArrow CPU array to NumPy array:
+Convert a PyArrow CPU array into a NumPy array:
.. code-block::
@@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array:
>>> np.from_dlpack(array)
array([2, 0, 2, 4])
-Convert a PyArrow CPU array to PyTorch tensor:
+Convert a PyArrow CPU array into a PyTorch tensor:
.. code-block::
>>> import torch
>>> torch.from_dlpack(array)
tensor([2, 0, 2, 4])
+
+Convert a PyArrow CPU array into a JAX array:
+
+.. code-block::
+
+ >>> import jax
+ >>> jax.numpy.from_dlpack(array)
+ Array([2, 0, 2, 4], dtype=int32)
+ >>> jax.dlpack.from_dlpack(array)
+ Array([2, 0, 2, 4], dtype=int32)