This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cf28b46  Add magic method `abs` to NDArray and Symbol. (#15680)
cf28b46 is described below

commit cf28b46ecb2342e3010f9ed1c6b17ee3533246f9
Author: kshitij12345 <[email protected]>
AuthorDate: Fri Aug 2 10:23:09 2019 +0530

    Add magic method `abs` to NDArray and Symbol. (#15680)
    
    * add magic method abs to ndarray
    
    * add relevant tests
    
    * add magic method abs to symbol
    
    * add relevant tests
    
    * retrigger CI
    
    * retrigger CI
---
 python/mxnet/ndarray/ndarray.py       |  4 ++++
 python/mxnet/symbol/symbol.py         |  4 ++++
 tests/python/unittest/test_ndarray.py |  9 +++++++++
 tests/python/unittest/test_symbol.py  | 17 ++++++++++++++++-
 4 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 3d8a7aa..0b7dca4 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -205,6 +205,10 @@ fixed-size items.
             self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
         return shared_pid.value, shared_id.value, self.shape, self.dtype
 
+    def __abs__(self):
+        """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x)"""
+        return self.abs()
+
     def __add__(self, other):
         """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
         return add(self, other)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 1e2defa..6832229 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -93,6 +93,10 @@ class Symbol(SymbolBase):
         """
         return (self[i] for i in range(len(self)))
 
+    def __abs__(self):
+        """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x)"""
+        return self.abs()
+
     def __add__(self, other):
         """x.__add__(y) <=> x+y
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 56db1eb..0f154bd 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -173,6 +173,15 @@ def test_ndarray_negate():
 
 
 @with_seed()
+def test_ndarray_magic_abs():
+    for dim in range(1, 7):
+        shape = rand_shape_nd(dim)
+        npy = np.random.uniform(-10, 10, shape)
+        arr = mx.nd.array(npy)
+        assert_almost_equal(abs(arr).asnumpy(), np.abs(arr.asnumpy()))
+
+
+@with_seed()
 def test_ndarray_reshape():
     tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
     true_res = mx.nd.arange(30) + 1
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 0c97c68..963b324 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -22,7 +22,7 @@ import mxnet as mx
 import numpy as np
 from common import assertRaises, models
 from mxnet.base import NotImplementedForSymbol
-from mxnet.test_utils import discard_stderr
+from mxnet.test_utils import discard_stderr, rand_shape_nd
 import pickle as pkl
 
 def test_symbol_basic():
@@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
     assert arg_shapes[1] == overwrite_shape
     assert out_shapes[0] == overwrite_shape
 
+
+def test_symbol_magic_abs():
+    for dim in range(1, 7):
+        with mx.name.NameManager():
+            data = mx.symbol.Variable('data')
+            method = data.abs(name='abs0')
+            magic = abs(data)
+            regular = mx.symbol.abs(data, name='abs0')
+            ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
+            mx.test_utils.check_consistency(
+                [method, magic], ctx_list=[ctx, ctx])
+            mx.test_utils.check_consistency(
+                [regular, magic], ctx_list=[ctx, ctx])
+
+
 def test_symbol_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 
'sum', 'nansum', 'prod',
                     'nanprod', 'mean', 'max', 'min', 'reshape', 
'broadcast_to', 'split',

Reply via email to