This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 75e4d1d  [API NEW][ARRAY METHOD] Add __Index__() and __array_namespace__() (#20689)
75e4d1d is described below

commit 75e4d1d3e41cc5aef9a0141e40127a75e02ccd04
Author: Zhenghui Jin <[email protected]>
AuthorDate: Mon Nov 1 12:21:41 2021 -0700

    [API NEW][ARRAY METHOD] Add __Index__() and __array_namespace__() (#20689)
    
    * [API] Add method __index__() and __array_namespace__()
    
    * update doc
    
    * fix lint
    
    * add tests
    
    * update tests
---
 docs/python_docs/python/api/np/arrays.ndarray.rst |  3 +-
 python/mxnet/numpy/multiarray.py                  | 35 +++++++++++++++++++++++
 tests/python/unittest/test_numpy_ndarray.py       | 32 +++++++++++++++++++++
 3 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/docs/python_docs/python/api/np/arrays.ndarray.rst b/docs/python_docs/python/api/np/arrays.ndarray.rst
index e77d20b..522a667 100644
--- a/docs/python_docs/python/api/np/arrays.ndarray.rst
+++ b/docs/python_docs/python/api/np/arrays.ndarray.rst
@@ -512,12 +512,13 @@ Container customization: (see :ref:`Indexing <arrays.indexing>`)
    ndarray.__getitem__
    ndarray.__setitem__
 
-Conversion; the operations :func:`int()` and :func:`float()`.
+Conversion; the operations :func:`index()`, :func:`int()` and :func:`float()`.
 They work only on arrays that have one element in them
 and return the appropriate scalar.
 
 .. autosummary::
 
+   ndarray.__index__
    ndarray.__int__
    ndarray.__float__
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index e8fa2a7..148b129 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -31,6 +31,8 @@ except ImportError:
 from array import array as native_array
 import functools
 import ctypes
+import sys
+import datetime
 import warnings
 import numpy as _np
 from .. import _deferred_compute as dc
@@ -413,6 +415,34 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
                 return mx_np_func(*new_args, **new_kwargs)
 
 
+    def __array_namespace__(self, api_version=None):
+        """
+        Returns an object that has all the array API functions on it.
+
+        Notes
+        -----
+        This is a standard API in
+        https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none.
+
+        Parameters
+        ----------
+        self : ndarray
+            The indexing key.
+        api_version : Optional, string
+            string representing the version of the array API specification to be returned, in `YYYY.MM` form.
+            If it is None, it should return the namespace corresponding to latest version of the array API
+            specification.
+        """
+        if api_version is not None:
+            try:
+                date = datetime.datetime.strptime(api_version, '%Y.%m')
+                if date.year != 2021:
+                    raise ValueError
+            except ValueError:
+                raise ValueError(f"Unrecognized array API version: {api_version!r}")
+        return sys.modules[self.__module__]
+
+
     def _get_np_basic_indexing(self, key):
         """
         This function indexes ``self`` with a tuple of `slice` objects only.
@@ -1303,6 +1333,11 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
 
     __nonzero__ = __bool__
 
+    def __index__(self):
+        if self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer):
+            return self.item()
+        raise TypeError('only integer scalar arrays can be converted to a scalar index')
+
     def __float__(self):
         num_elements = self.size
         if num_elements != 1:
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 559b8a5..2da60aa 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -21,6 +21,7 @@ from __future__ import division
 import itertools
 import os
 import pytest
+import operator
 import numpy as _np
 import mxnet as mx
 from mxnet import np, npx, autograd
@@ -1426,3 +1427,34 @@ def test_mixed_array_types_share_memory():
 def test_save_load_empty(tmp_path):
     mx.npx.savez(str(tmp_path / 'params.npz'))
     mx.npx.load(str(tmp_path / 'params.npz'))
+
+@use_np
[email protected]('shape', [
+    (),
+    (1,),
+    (1,2)
+])
[email protected]('dtype', ['float16', 'float32', 'float64', 'bool', 'int32'])
+def test_index_operator(shape, dtype):
+    if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer):
+        x = np.ones(shape=shape, dtype=dtype)
+        pytest.raises(TypeError, operator.index, x)
+    else:
+        assert operator.index(np.ones(shape=shape, dtype=dtype)) == \
+            operator.index(_np.ones(shape=shape, dtype=dtype))
+
+
[email protected]('api_version, raise_exception', [
+    (None, False),
+    ('2021.10', False),
+    ('2020.09', True),
+    ('2021.24', True),
+])
+def test_array_namespace(api_version, raise_exception):
+    x = np.array([1, 2, 3], dtype="float64")
+    if raise_exception:
+        pytest.raises(ValueError, x.__array_namespace__, api_version)
+    else:
+        xp = x.__array_namespace__(api_version)
+        y = xp.array([1, 2, 3], dtype="float64")
+        assert same(x, y)

Reply via email to