This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 596ef3a Add nd.power and sym.pow (#14606)
596ef3a is described below
commit 596ef3a9a2dc1c59c388ef5259a9eb8a6b6a1beb
Author: Hao Jin <[email protected]>
AuthorDate: Thu Apr 11 14:53:54 2019 -0700
Add nd.power and sym.pow (#14606)
* add nd.power and sym.pow
* deprecate sym.pow, get rid of nd.pow
---
python/mxnet/symbol/symbol.py | 41 +++++++++++++++++++++++++++++++++-
tests/python/unittest/test_operator.py | 15 ++++++++-----
2 files changed, 50 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 0c0a0a1..91d4ca1 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -47,7 +47,7 @@ from . import op
 from ._internal import SymbolBase, _set_symbol_class
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
-           "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
+           "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
            "histogram", "split_v2"]
@@ -2740,6 +2740,8 @@ def pow(base, exp):
Both inputs can be Symbol or scalar number.
Broadcasting is not supported. Use `broadcast_pow` instead.
+ `sym.pow` is being deprecated, please use `sym.power` instead.
+
Parameters
---------
base : Symbol or scalar
@@ -2780,6 +2782,43 @@ def pow(base, exp):
raise TypeError('types (%s, %s) not supported' % (str(type(base)),
str(type(exp))))
+def power(base, exp):
+    """Returns element-wise result of base element raised to powers from exp element.
+
+    Both inputs can be Symbol or scalar number.
+    Broadcasting is not supported. Use `broadcast_pow` instead.
+
+    Parameters
+    ----------
+    base : Symbol or scalar
+        The base symbol
+    exp : Symbol or scalar
+        The exponent symbol
+
+    Returns
+    -------
+    Symbol or scalar
+        The bases in `base` raised to the exponents in `exp`.
+
+    Examples
+    --------
+    >>> mx.sym.power(2, 3)
+    8
+    >>> x = mx.sym.Variable('x')
+    >>> y = mx.sym.Variable('y')
+    >>> z = mx.sym.power(x, 2)
+    >>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy()
+    array([ 1., 4.], dtype=float32)
+    >>> z = mx.sym.power(3, y)
+    >>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy()
+    array([ 9., 27.], dtype=float32)
+    >>> z = mx.sym.power(x, y)
+    >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy()
+    array([ 9., 64.], dtype=float32)
+    """
+    return pow(base, exp)  # module-level sym.pow defined above, not builtins.pow
+
+
# pylint: disable=no-member
# pylint: disable=redefined-builtin
def maximum(left, right):
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 17618e4..ccb351f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -698,11 +698,11 @@ def test_symbol_pow():
 def test_pow_fn():
     shape = (3, 4)
     exp = mx.symbol.Variable("exp")
-    y = mx.sym.pow(2, exp)
     x = np.ones(shape)*3
-    check_numeric_gradient(y, [x], numeric_eps=1E-3)
-    check_symbolic_forward(y, [x], [2**x])
-    check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])
+    for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]:
+        check_numeric_gradient(y, [x], numeric_eps=1E-3)
+        check_symbolic_forward(y, [x], [2**x])
+        check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x])
 
 
 @with_seed()
@@ -6675,7 +6675,12 @@ def test_binary_math_operators():
lambda x, y: np.power(x, y),
lambda x, y: np.power(x, y - 1.) * y,
lambda x, y: np.power(x, y) * np.log(x),
- 0.2, 5.0, -4.0, 4.0]
+ 0.2, 5.0, -4.0, 4.0],
+ 'power': [lambda x, y: mx.sym.power(x, y),
+ lambda x, y: np.power(x, y),
+ lambda x, y: np.power(x, y - 1.) * y,
+ lambda x, y: np.power(x, y) * np.log(x),
+ 0.2, 5.0, -4.0, 4.0]
}
# Loop over operators
for name, op in binary_ops.items():