This is an automated email from the ASF dual-hosted git repository.

felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fa78d0d1e GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230)
9fa78d0d1e is described below

commit 9fa78d0d1e2cc22086e8d79afba518710a1e657e
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Wed Sep 25 21:20:47 2024 -0300

    GH-44229: [Docs] Add PyArrow to JAX example to the docs (#44230)
    
    ### Rationale for this change
    
    Explicitly mention in the docs a way that PyArrow can interop with [JAX](https://github.com/jax-ml/jax).
    
    ### What changes are included in this PR?
    
     - Tweaks to the phrasing
     - Two JAX examples: one for `jax.numpy` and another for `jax.dlpack`
    
    ### Are these changes tested?
    
    N/A
    * GitHub Issue: #44229
    
    Authored-by: Felipe Oliveira Carvalho <[email protected]>
    Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
 docs/source/python/dlpack.rst | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/docs/source/python/dlpack.rst b/docs/source/python/dlpack.rst
index 024c2800e1..9f0d3b58aa 100644
--- a/docs/source/python/dlpack.rst
+++ b/docs/source/python/dlpack.rst
@@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol
 (``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can
 thus be consumed by libraries implementing ``from_dlpack``.
 
-Example
--------
+Examples
+--------
 
-Convert a PyArrow CPU array to NumPy array:
+Convert a PyArrow CPU array into a NumPy array:
 
 .. code-block::
 
@@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array:
     >>> np.from_dlpack(array)
     array([2, 0, 2, 4])
 
-Convert a PyArrow CPU array to PyTorch tensor:
+Convert a PyArrow CPU array into a PyTorch tensor:
 
 .. code-block::
 
     >>> import torch
     >>> torch.from_dlpack(array)
     tensor([2, 0, 2, 4])
+
+Convert a PyArrow CPU array into a JAX array:
+
+.. code-block::
+
+    >>> import jax
+    >>> jax.numpy.from_dlpack(array)
+    Array([2, 0, 2, 4], dtype=int32)
+    >>> jax.dlpack.from_dlpack(array)
+    Array([2, 0, 2, 4], dtype=int32)
