This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a1464f  [API NEW][METHOD] Add mT, permute_dims (#20688)
1a1464f is described below

commit 1a1464fb8469d4806a4b96527fd225363b529201
Author: Zhenghui Jin <[email protected]>
AuthorDate: Fri Oct 29 19:01:54 2021 -0700

    [API NEW][METHOD] Add mT, permute_dims (#20688)
    
    * Add mT, permute_dims
    
    * fix
    
    * fix lint
    
    * fix
    
    * fix tests
---
 .../gluon/probability/distributions/constraint.py  |  2 +-
 .../distributions/multivariate_normal.py           |  9 +-
 python/mxnet/numpy/linalg.py                       | 16 ++--
 python/mxnet/numpy/multiarray.py                   | 55 +++++++++++-
 tests/python/unittest/test_gluon_probability_v2.py | 14 ++-
 tests/python/unittest/test_numpy_op.py             | 99 ++++++++++++++++++++++
 6 files changed, 172 insertions(+), 23 deletions(-)
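
For quick reference, a minimal usage sketch of the array-API-style helpers this
commit introduces (a hypothetical session against mxnet.numpy; shapes follow from
the docstrings in the diff below):

    >>> from mxnet import np
    >>> x = np.ones((4, 2, 3))
    >>> x.mT.shape                            # transpose of the last two axes
    (4, 3, 2)
    >>> np.linalg.matrix_transpose(x).shape   # same operation as a function
    (4, 3, 2)
    >>> np.permute_dims(x, (2, 0, 1)).shape   # alias for np.transpose
    (3, 4, 2)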

diff --git a/python/mxnet/gluon/probability/distributions/constraint.py b/python/mxnet/gluon/probability/distributions/constraint.py
index 5f6d59a..42ec967 100644
--- a/python/mxnet/gluon/probability/distributions/constraint.py
+++ b/python/mxnet/gluon/probability/distributions/constraint.py
@@ -461,7 +461,7 @@ class PositiveDefinite(Constraint):
                   " positive definite matrices".format(value)
         eps = 1e-5
         condition = np.all(
-            np.abs(value - np.swapaxes(value, -1, -2)) < eps, axis=-1)
+            np.abs(value - value.mT) < eps, axis=-1)
         condition = condition & (np.linalg.eigvals(value) > 0)
         _value = constraint_check()(condition, err_msg) * value
         return _value
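
The constraint change above relies on `value.mT` being exactly the old
`np.swapaxes(value, -1, -2)`; a small sketch of that equivalence (assuming
`value` is any stack of square matrices):

    >>> from mxnet import np
    >>> value = np.random.normal(0, 1, (2, 3, 3))
    >>> old = np.abs(value - np.swapaxes(value, -1, -2))
    >>> new = np.abs(value - value.mT)
    >>> bool(np.all(old == new))
    True
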
diff --git a/python/mxnet/gluon/probability/distributions/multivariate_normal.py b/python/mxnet/gluon/probability/distributions/multivariate_normal.py
index 40e7c4c..049ddba 100644
--- a/python/mxnet/gluon/probability/distributions/multivariate_normal.py
+++ b/python/mxnet/gluon/probability/distributions/multivariate_normal.py
@@ -72,8 +72,7 @@ class MultivariateNormal(Distribution):
         L = flip(Cholesky(flip(P))).T
         """
         L_flip_inv_T = np.linalg.cholesky(np.flip(P, (-1, -2)))
-        L = np.linalg.inv(np.swapaxes(
-            np.flip(L_flip_inv_T, (-1, -2)), -1, -2))
+        L = np.linalg.inv(np.flip(L_flip_inv_T, (-1, -2)).mT)
         return L
 
     @cached_property
@@ -87,8 +86,7 @@ class MultivariateNormal(Distribution):
     def cov(self):
         # pylint: disable=method-hidden
         if 'scale_tril' in self.__dict__:
-            scale_triu = np.swapaxes(self.scale_tril, -1, -2)
-            return np.matmul(self.scale_tril, scale_triu)
+            return np.matmul(self.scale_tril, self.scale_tril.mT)
         return np.linalg.inv(self.precision)
 
     @cached_property
@@ -97,8 +95,7 @@ class MultivariateNormal(Distribution):
         if 'cov' in self.__dict__:
             return np.linalg.inv(self.cov)
         scale_tril_inv = np.linalg.inv(self.scale_tril)
-        scale_triu_inv = np.swapaxes(scale_tril_inv, -1, -2)
-        return np.matmul(scale_triu_inv, scale_tril_inv)
+        return np.matmul(scale_tril_inv.mT, scale_tril_inv)
 
     @property
     def mean(self):
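
The covariance code above now builds sigma = L @ L.T with `.mT`; a tiny numeric
sketch (made-up 2x2 lower-triangular L, not values from the source; printed
output is approximate):

    >>> from mxnet import np
    >>> L = np.array([[2., 0.],
    ...               [1., 3.]])
    >>> np.matmul(L, L.mT)          # same construction as cov() above
    array([[ 4.,  2.],
           [ 2., 10.]])
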
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 23e8a51..6f96f09 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -82,9 +82,9 @@ def matrix_transpose(a):
 
     Notes
     -----
-    `matrix_transpose` is an alias for `transpose`. It is a standard API in
+    `matrix_transpose` is a new API in the array API standard:
     https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-transpose-x
-    instead of an official NumPy operator.
+    instead of an official NumPy operator. Unlike transpose, it only transposes the last two axes.
 
     Parameters
     ----------
@@ -103,14 +103,18 @@ def matrix_transpose(a):
     >>> x
     array([[0., 1.],
            [2., 3.]])
-    >>> np.transpose(x)
+    >>> np.linalg.matrix_transpose(x)
     array([[0., 2.],
            [1., 3.]])
     >>> x = np.ones((1, 2, 3))
-    >>> np.transpose(x, (1, 0, 2)).shape
-    (2, 1, 3)
+    >>> np.linalg.matrix_transpose(x)
+    array([[[1., 1.],
+            [1., 1.],
+            [1., 1.]]])
     """
-    return _mx_nd_np.transpose(a, axes=None)
+    if a.ndim < 2:
+        raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+    return _mx_nd_np.swapaxes(a, -1, -2)
 
 
 def trace(a, offset=0):
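
To make the updated Notes concrete: unlike `np.transpose`, which reverses every
axis by default, `matrix_transpose` only swaps the trailing two (a sketch, not
part of the patch):

    >>> from mxnet import np
    >>> x = np.ones((2, 3, 4))
    >>> np.transpose(x).shape                 # reverses all axes
    (4, 3, 2)
    >>> np.linalg.matrix_transpose(x).shape   # swaps only the last two
    (2, 4, 3)
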
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 32d7aa3..04bace5 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -81,7 +81,7 @@ __all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median',
            'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
            'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
            'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
-           'positive', 'logaddexp', 'floor_divide']
+           'positive', 'logaddexp', 'floor_divide', 'permute_dims']
 
 __all__ += fallback.__all__
 
@@ -1333,9 +1333,22 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
     # pylint: disable= invalid-name, undefined-variable
     def T(self):
         """Same as self.transpose(). This always returns a copy of self."""
+        if self.ndim != 2:
+            warnings.warn('x.T requires x to have 2 dimensions. '
+                          'Use x.mT to transpose stacks of matrices and '
+                          'permute_dims() to permute dimensions.')
         return self.transpose()
     # pylint: enable= invalid-name, undefined-variable
 
+    @property
+    # pylint: disable= invalid-name, undefined-variable
+    def mT(self):
+        """Same as self.transpose(). This always returns a copy of self."""
+        if self.ndim < 2:
+            raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+        return _mx_nd_np.swapaxes(self, -1, -2)
+    # pylint: enable= invalid-name, undefined-variable
+
     def all(self, axis=None, out=None, keepdims=False):
         return _mx_nd_np.all(self, axis=axis, out=out, keepdims=keepdims)
 
@@ -6448,6 +6461,46 @@ def transpose(a, axes=None):
 
 
 @set_module('mxnet.numpy')
+def permute_dims(a, axes=None):
+    """
+    Permute the dimensions of an array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axes : list of ints, optional
+        By default, reverse the dimensions,
+        otherwise permute the axes according to the values given.
+
+    Returns
+    -------
+    p : ndarray
+        a with its axes permuted.
+
+    Notes
+    -----
+    `permute_dims` is an alias for `transpose`. It is a standard API in
+    https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html#permute-dims-x-axes
+    instead of an official NumPy operator.
+
+    Examples
+    --------
+    >>> x = np.arange(4).reshape((2,2))
+    >>> x
+    array([[0., 1.],
+           [2., 3.]])
+    >>> np.permute_dims(x)
+    array([[0., 2.],
+           [1., 3.]])
+    >>> x = np.ones((1, 2, 3))
+    >>> np.permute_dims(x, (1, 0, 2)).shape
+    (2, 1, 3)
+    """
+    return _mx_nd_np.transpose(a, axes)
+
+
+@set_module('mxnet.numpy')
 def repeat(a, repeats, axis=None):
     """
     Repeat elements of an array.
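
Taken together, the multiarray.py changes give three ways to reorder axes; a
short comparison sketch (shapes only, assuming a 3-D input):

    >>> from mxnet import np
    >>> x = np.ones((2, 3, 4))
    >>> x.T.shape                             # reverses all axes; warns when ndim != 2
    (4, 3, 2)
    >>> x.mT.shape                            # matrix transpose of the trailing axes
    (2, 4, 3)
    >>> np.permute_dims(x, (1, 0, 2)).shape   # explicit permutation
    (3, 2, 4)
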
diff --git a/tests/python/unittest/test_gluon_probability_v2.py b/tests/python/unittest/test_gluon_probability_v2.py
index c25fa71..ba8de9a 100644
--- a/tests/python/unittest/test_gluon_probability_v2.py
+++ b/tests/python/unittest/test_gluon_probability_v2.py
@@ -1632,8 +1632,7 @@ def test_gluon_mvn():
         Force the precision matrix to be symmetric.
         """
         precision = np.linalg.inv(cov)
-        precision_t = np.swapaxes(precision, -1, -2)
-        return (precision + precision_t) / 2
+        return (precision + precision.mT) / 2
 
     event_shapes = [3, 5]
     loc_shapes = [(), (2,), (4, 2)]
@@ -1653,8 +1652,7 @@ def test_gluon_mvn():
                 loc.attach_grad()
                 _s.attach_grad()
                 # Full covariance matrix
-                sigma = np.matmul(_s, np.swapaxes(
-                    _s, -1, -2)) + np.eye(event_shape)
+                sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
                 cov_param = cov_func[cov_type](sigma)
                 net = TestMVN('sample', cov_type)
                 if hybridize:
@@ -1678,8 +1676,7 @@ def test_gluon_mvn():
                 loc.attach_grad()
                 _s.attach_grad()
                 # Full covariance matrix
-                sigma = np.matmul(_s, np.swapaxes(
-                    _s, -1, -2)) + np.eye(event_shape)
+                sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
                 cov_param = cov_func[cov_type](sigma)
                 net = TestMVN('log_prob', cov_type)
                 if hybridize:
@@ -1709,8 +1706,7 @@ def test_gluon_mvn():
                 loc.attach_grad()
                 _s.attach_grad()
                 # Full covariance matrix
-                sigma = np.matmul(_s, np.swapaxes(
-                    _s, -1, -2)) + np.eye(event_shape)
+                sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
                 cov_param = cov_func[cov_type](sigma)
                 net = TestMVN('entropy', cov_type)
                 if hybridize:
@@ -2093,7 +2089,7 @@ def test_gluon_kl():
     for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes):
         loc = np.random.randn(*(loc_shape + (event_shape,)))
         _s = np.random.randn(*(cov_shape + (event_shape, event_shape)))
-        sigma = np.matmul(_s, np.swapaxes(_s, -1, -2)) + np.eye(event_shape)
+        sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
         dist = mgp.MultivariateNormal(loc, cov=sigma)
         desired_shape = (loc + sigma[..., 0]).shape[:-1]
         _test_zero_kl(dist, desired_shape)
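
The test updates above keep building symmetric positive-definite covariances the
same way, only via `.mT`; a minimal sketch of that construction (names here are
illustrative, not taken from the test file):

    >>> from mxnet import np
    >>> n = 3
    >>> s = np.random.randn(2, n, n)
    >>> sigma = np.matmul(s, s.mT) + np.eye(n)   # batch of SPD matrices
    >>> sigma.shape
    (2, 3, 3)
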
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 31c41de..1572061 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2581,6 +2581,67 @@ def test_np_transpose_error():
 
 
 @use_np
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('dtype', [onp.float32, onp.float16, onp.int32])
+@pytest.mark.parametrize('data_shape,axes_workload', [
+    [(), [(), None]],
+    [(2,), [(0,), None]],
+    [(0, 2), [(0, 1), (1, 0)]],
+    [(5, 10), [(0, 1), (1, 0), None]],
+    [(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]],
+    [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
+    [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
+    [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
+    [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]],
+    [(3, 4, 3, 4, 3, 2, 2), [(0, 1, 3, 2, 4, 5, 6),
+     (2, 3, 4, 1, 0, 5, 6), None]],
+    [(3, 4, 3, 4, 3, 2, 3, 2), [(0, 1, 3, 2, 4, 5, 7, 6),
+     (2, 3, 4, 1, 0, 5, 7, 6), None]],
+])
+@pytest.mark.parametrize('grad_req', ['write', 'add'])
+def test_np_permute_dims(data_shape, axes_workload, hybridize, dtype, grad_req):
+    def np_permute_dims_grad(out_shape, dtype, axes=None):
+        ograd = onp.ones(out_shape, dtype=dtype)
+        if axes is None or axes == ():
+            return onp.transpose(ograd, axes)
+        np_axes = onp.array(list(axes))
+        permute_dims_axes = onp.zeros_like(np_axes)
+        permute_dims_axes[np_axes] = onp.arange(len(np_axes))
+        return onp.transpose(ograd, tuple(list(permute_dims_axes)))
+
+    class TestPermuteDims(HybridBlock):
+        def __init__(self, axes=None):
+            super(TestPermuteDims, self).__init__()
+            self.axes = axes
+
+        def forward(self, a):
+            return np.permute_dims(a, self.axes)
+
+    for axes in axes_workload:
+        test_trans = TestPermuteDims(axes)
+        if hybridize:
+            test_trans.hybridize()
+        x = np.random.normal(0, 1, data_shape).astype(dtype)
+        x = x.astype(dtype)
+        x.attach_grad(grad_req=grad_req)
+        if grad_req == 'add':
+            x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype)
+            x_grad_np = x.grad.asnumpy()
+        np_out = onp.transpose(x.asnumpy(), axes)
+        with mx.autograd.record():
+            mx_out = test_trans(x)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+        mx_out.backward()
+        np_backward = np_permute_dims_grad(np_out.shape, dtype, axes)
+        if grad_req == 'add':
+            assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np,
+                                rtol=1e-3, atol=1e-5, use_broadcast=False)
+        else:
+            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
+
+
+@use_np
 def test_np_meshgrid():
     nx, ny = (4, 5)
     x = np.array(onp.linspace(0, 1, nx), dtype=np.float32)
@@ -6815,6 +6876,44 @@ def test_np_linalg_matrix_rank():
 
 
 @use_np
+@pytest.mark.parametrize('shape', [
+    (),
+    (1,),
+    (0, 1, 2),
+    (0, 1, 2),
+    (0, 1, 2),
+    (4, 5, 6, 7),
+    (4, 5, 6, 7),
+    (4, 5, 6, 7),
+])
+def test_np_linalg_matrix_transpose(shape):
+    class TestMatTranspose(HybridBlock):
+        def __init__(self):
+            super(TestMatTranspose, self).__init__()
+
+        def forward(self, x):
+            return np.linalg.matrix_transpose(x)
+
+    data_np = onp.random.uniform(size=shape)
+    data_mx = np.array(data_np, dtype=data_np.dtype)
+    if data_mx.ndim < 2:
+        assertRaises(ValueError, np.linalg.matrix_transpose, data_mx)
+        return
+    ret_np = onp.swapaxes(data_np, -1, -2)
+    ret_mx = np.linalg.matrix_transpose(data_mx)
+    assert same(ret_mx.asnumpy(), ret_np)
+
+    net = TestMatTranspose()
+    for hybrid in [False, True]:
+        if hybrid:
+            net.hybridize()
+        ret_mx = net(data_mx)
+        assert same(ret_mx.asnumpy(), ret_np)
+    
+    assert same(data_mx.mT.asnumpy(), ret_np)
+
+
+@use_np
 def test_np_linalg_pinv():
     class TestPinv(HybridBlock):
         def __init__(self, hermitian):
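
For context on np_permute_dims_grad in the new test: the backward pass of
transpose(a, axes) transposes the incoming gradient by the inverse permutation,
which the helper builds by scattering arange(len(axes)) into positions axes. A
standalone NumPy sketch of that identity (not part of the patch):

    >>> import numpy as onp
    >>> axes = (2, 0, 1)
    >>> inv = onp.zeros(len(axes), dtype=int)
    >>> inv[list(axes)] = onp.arange(len(axes))
    >>> inv.tolist()
    [1, 2, 0]
    >>> a = onp.arange(24).reshape(2, 3, 4)
    >>> bool((onp.transpose(onp.transpose(a, axes), tuple(inv)) == a).all())
    True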
