This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cf28b46 Add magic method `abs` to NDArray and Symbol. (#15680)
cf28b46 is described below
commit cf28b46ecb2342e3010f9ed1c6b17ee3533246f9
Author: kshitij12345 <[email protected]>
AuthorDate: Fri Aug 2 10:23:09 2019 +0530
Add magic method `abs` to NDArray and Symbol. (#15680)
* add magic method abs to ndarray
* add relevant tests
* add magic method abs to symbol
* add relevant tests
* retrigger CI
* retrigger CI
---
python/mxnet/ndarray/ndarray.py | 4 ++++
python/mxnet/symbol/symbol.py | 4 ++++
tests/python/unittest/test_ndarray.py | 9 +++++++++
tests/python/unittest/test_symbol.py | 17 ++++++++++++++++-
4 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 3d8a7aa..0b7dca4 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -205,6 +205,10 @@ fixed-size items.
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
return shared_pid.value, shared_id.value, self.shape, self.dtype
+ def __abs__(self):
+ """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x)"""
+ return self.abs()
+
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
return add(self, other)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 1e2defa..6832229 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -93,6 +93,10 @@ class Symbol(SymbolBase):
"""
return (self[i] for i in range(len(self)))
+ def __abs__(self):
+ """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x)"""
+ return self.abs()
+
def __add__(self, other):
"""x.__add__(y) <=> x+y
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 56db1eb..0f154bd 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -173,6 +173,15 @@ def test_ndarray_negate():
@with_seed()
+def test_ndarray_magic_abs():
+ for dim in range(1, 7):
+ shape = rand_shape_nd(dim)
+ npy = np.random.uniform(-10, 10, shape)
+ arr = mx.nd.array(npy)
+ assert_almost_equal(abs(arr).asnumpy(), arr.abs().asnumpy())
+
+
+@with_seed()
def test_ndarray_reshape():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
true_res = mx.nd.arange(30) + 1
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 0c97c68..963b324 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -22,7 +22,7 @@ import mxnet as mx
import numpy as np
from common import assertRaises, models
from mxnet.base import NotImplementedForSymbol
-from mxnet.test_utils import discard_stderr
+from mxnet.test_utils import discard_stderr, rand_shape_nd
import pickle as pkl
def test_symbol_basic():
@@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
assert arg_shapes[1] == overwrite_shape
assert out_shapes[0] == overwrite_shape
+
+def test_symbol_magic_abs():
+ for dim in range(1, 7):
+ with mx.name.NameManager():
+ data = mx.symbol.Variable('data')
+ method = data.abs(name='abs0')
+ magic = abs(data)
+ regular = mx.symbol.abs(data, name='abs0')
+ ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
+ mx.test_utils.check_consistency(
+ [method, magic], ctx_list=[ctx, ctx])
+ mx.test_utils.check_consistency(
+ [regular, magic], ctx_list=[ctx, ctx])
+
+
def test_symbol_fluent():
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose',
'sum', 'nansum', 'prod',
'nanprod', 'mean', 'max', 'min', 'reshape',
'broadcast_to', 'split',