This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 75e4d1d [API NEW][ARRAY METHOD] Add __index__() and
__array_namespace__() (#20689)
75e4d1d is described below
commit 75e4d1d3e41cc5aef9a0141e40127a75e02ccd04
Author: Zhenghui Jin <[email protected]>
AuthorDate: Mon Nov 1 12:21:41 2021 -0700
[API NEW][ARRAY METHOD] Add __index__() and __array_namespace__() (#20689)
* [API] Add method __index__() and __array_namespace__()
* update doc
* fix lint
* add tests
* update tests
---
docs/python_docs/python/api/np/arrays.ndarray.rst | 3 +-
python/mxnet/numpy/multiarray.py | 35 +++++++++++++++++++++++
tests/python/unittest/test_numpy_ndarray.py | 32 +++++++++++++++++++++
3 files changed, 69 insertions(+), 1 deletion(-)
diff --git a/docs/python_docs/python/api/np/arrays.ndarray.rst
b/docs/python_docs/python/api/np/arrays.ndarray.rst
index e77d20b..522a667 100644
--- a/docs/python_docs/python/api/np/arrays.ndarray.rst
+++ b/docs/python_docs/python/api/np/arrays.ndarray.rst
@@ -512,12 +512,13 @@ Container customization: (see :ref:`Indexing
<arrays.indexing>`)
ndarray.__getitem__
ndarray.__setitem__
-Conversion; the operations :func:`int()` and :func:`float()`.
+Conversion; the operations :func:`operator.index()`, :func:`int()` and :func:`float()`.
They work only on arrays that have one element in them
and return the appropriate scalar.
.. autosummary::
+ ndarray.__index__
ndarray.__int__
ndarray.__float__
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index e8fa2a7..148b129 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -31,6 +31,8 @@ except ImportError:
from array import array as native_array
import functools
import ctypes
+import sys
+import datetime
import warnings
import numpy as _np
from .. import _deferred_compute as dc
@@ -413,6 +415,34 @@ class ndarray(NDArray): # pylint: disable=invalid-name
return mx_np_func(*new_args, **new_kwargs)
def __array_namespace__(self, api_version=None):
    """
    Returns an object that has all the array API functions on it.

    Notes
    -----
    This is a standard API in
    https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none.

    Parameters
    ----------
    self : ndarray
        The array whose array-API namespace is requested.
    api_version : Optional, string
        String representing the version of the array API specification to be
        returned, in `YYYY.MM` form. If it is None, the namespace corresponding
        to the latest version of the array API specification is returned.
        Only versions from the year 2021 are recognized.

    Returns
    -------
    out : module
        The module recorded as ``self.__module__``, looked up in
        ``sys.modules``.

    Raises
    ------
    ValueError
        If `api_version` cannot be parsed as `YYYY.MM` or its year is not 2021.
    """
    if api_version is not None:
        try:
            date = datetime.datetime.strptime(api_version, '%Y.%m')
        except ValueError:
            date = None
        # Raised outside the `except` so the user-facing error is not chained
        # to the internal strptime failure.
        if date is None or date.year != 2021:
            raise ValueError(f"Unrecognized array API version: {api_version!r}")
    return sys.modules[self.__module__]
+
+
def _get_np_basic_indexing(self, key):
"""
This function indexes ``self`` with a tuple of `slice` objects only.
@@ -1303,6 +1333,11 @@ class ndarray(NDArray): # pylint: disable=invalid-name
__nonzero__ = __bool__
def __index__(self):
    """Convert a 0-d integer array to a Python ``int`` for use as an index.

    This backs ``operator.index(x)``, so such arrays can be used wherever
    Python requires an integer index (e.g. slicing a list).

    Returns
    -------
    int
        The single element of ``self`` as a Python ``int``.

    Raises
    ------
    TypeError
        If ``self`` is not 0-d or its dtype is not an integer dtype. Boolean
        arrays are rejected because ``numpy.issubdtype(bool, numpy.integer)``
        is False.
    """
    if self.ndim != 0 or not _np.issubdtype(self.dtype, _np.integer):
        raise TypeError('only integer scalar arrays can be converted to a scalar index')
    # The data model requires __index__ to return an exact int; coerce in case
    # item() yields a NumPy integer scalar rather than a Python int.
    return int(self.item())
+
def __float__(self):
num_elements = self.size
if num_elements != 1:
diff --git a/tests/python/unittest/test_numpy_ndarray.py
b/tests/python/unittest/test_numpy_ndarray.py
index 559b8a5..2da60aa 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -21,6 +21,7 @@ from __future__ import division
import itertools
import os
import pytest
+import operator
import numpy as _np
import mxnet as mx
from mxnet import np, npx, autograd
@@ -1426,3 +1427,34 @@ def test_mixed_array_types_share_memory():
def test_save_load_empty(tmp_path):
    # Saving zero arrays must still produce a file that can be loaded back.
    archive = str(tmp_path / 'params.npz')
    mx.npx.savez(archive)
    mx.npx.load(archive)
+
@use_np
@pytest.mark.parametrize('shape', [
    (),
    (1,),
    (1, 2),
])
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'bool',
                                   'int8', 'uint8', 'int32', 'int64'])
def test_index_operator(shape, dtype):
    """``operator.index`` on an mxnet ndarray must agree with NumPy.

    Only 0-d arrays of an integer dtype are valid indices; every other
    shape/dtype combination must raise ``TypeError``.
    """
    x = np.ones(shape=shape, dtype=dtype)
    if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer):
        pytest.raises(TypeError, operator.index, x)
    else:
        expected = operator.index(_np.ones(shape=shape, dtype=dtype))
        assert operator.index(x) == expected
+
+
@use_np
@pytest.mark.parametrize('api_version, raise_exception', [
    (None, False),
    ('2021.10', False),
    ('2021.12', False),
    ('2020.09', True),
    ('2021.24', True),
    ('2021', True),
])
def test_array_namespace(api_version, raise_exception):
    """``ndarray.__array_namespace__`` returns a usable namespace for
    recognized API versions and raises ``ValueError`` otherwise."""
    x = np.array([1, 2, 3], dtype="float64")
    if raise_exception:
        with pytest.raises(ValueError, match='Unrecognized array API version'):
            x.__array_namespace__(api_version)
    else:
        xp = x.__array_namespace__(api_version)
        y = xp.array([1, 2, 3], dtype="float64")
        assert same(x, y)