This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy by this push:
new 54a42c8 [numpy] Refactor np module (example runs through) (#15055)
54a42c8 is described below
commit 54a42c8a4d77affa70d62894c0dc0d832f98bbab
Author: reminisce <[email protected]>
AuthorDate: Sun May 26 21:19:43 2019 -0700
[numpy] Refactor np module (example runs through) (#15055)
* Refactor notebook
* notebook working with hybrid block
* More refactoring
* Remove unnecessary use_np_compat
* Use class decorator to initialize numpy ndarrays in parameter.py
* Clear notebook outputs
* Improve np decorator
* Remove npe op from optimizer
* Fix CI
* Fix functools.wraps issue in Python2
* Fix ci
---
example/numpy/demo.ipynb | 257 +++++++++++++++++-----------
include/mxnet/tuple.h | 7 +
python/mxnet/__init__.py | 3 +-
python/mxnet/base.py | 165 +-----------------
python/mxnet/gluon/block.py | 6 +-
python/mxnet/gluon/parameter.py | 21 ++-
python/mxnet/gluon/utils.py | 27 +++
python/mxnet/ndarray/ndarray.py | 6 +
python/mxnet/ndarray/numpy/_op.py | 12 +-
python/mxnet/ndarray/register.py | 62 +++++--
python/mxnet/numpy/__init__.py | 2 +-
python/mxnet/numpy/multiarray.py | 106 ++++--------
python/mxnet/optimizer/optimizer.py | 32 +++-
python/mxnet/symbol/numpy/_symbol.py | 49 +-----
python/mxnet/symbol/symbol.py | 3 +-
python/mxnet/util.py | 221 ++++++++++++++++++++++++
src/operator/numpy/np_dot.cc | 34 ++--
tests/python/gpu/test_operator_gpu.py | 1 +
tests/python/unittest/test_numpy_gluon.py | 112 ++++++++++++
tests/python/unittest/test_numpy_ndarray.py | 32 ++--
tests/python/unittest/test_numpy_op.py | 5 +-
21 files changed, 711 insertions(+), 452 deletions(-)
diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb
index 7ba184d..1f06275 100644
--- a/example/numpy/demo.ipynb
+++ b/example/numpy/demo.ipynb
@@ -4,13 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Fundamentals of MXNet Numpy Module\n",
+ "# Fundamentals of MXNet-NumPy Module\n",
"\n",
"## Namespaces for Imperative Programming\n",
"- `mxnet.numpy`: Regular NumPy operators\n",
"- `mxnet.numpy.random`: NumPy random operators\n",
"- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
- "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not
exist in the official NumPy\n",
+ "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not
exist in the official NumPy and some utils (e.g. context related functions).\n",
"\n",
"## Operator Namespaces for Gluon\n",
"`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and
`npe` are aliases of `numpy` and `numpy_extension`, respectively.\n",
@@ -20,7 +20,7 @@
"- `F.npe`: Operators implemented in MXNet that do not exist in official
NumPy\n",
"\n",
"## New `ndarray` and `symbol`\n",
- "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol`
(not visible to users)\n",
+ "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol`
(not directly visible to users)\n",
"- Same name as in the official NumPy package\n",
"- Dispatch convience fluent method calls to MXNet Numpy operators\n",
"- Override many convenience fluent methods that do not exist in the
official NumPy ndarray\n",
@@ -28,7 +28,19 @@
" - Indexing: `__getitem__` and `__setitem__`\n",
" - Many binary element-wise with broadcasting, not supported in
`mxnet.symbol.Symbol`\n",
" \n",
- "## Examples of ndarray and symbol Basics\n",
+ "## User Experience of Module Importing (In Progress)\n",
+ "**Legacy**\n",
+ "```python\n",
+ "import mxnet as mx\n",
+ "from mxnet import gluon\n",
+ "```\n",
+ "**Numpy**\n",
+ "```python\n",
+ "from mxnet import np, npe, gluon\n",
+ "```\n",
+ "\n",
+ " \n",
+ "## MXNet NumPy in Action\n",
"### Scalar and zero-size tensors"
]
},
@@ -41,9 +53,6 @@
"import mxnet as mx\n",
"from mxnet import numpy as np\n",
"\n",
- "# use numpy-compatible semantics\n",
- "mx.set_np_compat(True)\n",
- "\n",
"# create a scalar tensor\n",
"x = np.array(3.14)\n",
"print(x) # x is actually an ndarray, but a scalar value will be printed"
@@ -158,7 +167,63 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Binary element-wise operations with broadcasting in new and old
symbols"
+ "### There is a line between classic operators and numpy operators...\n",
+ "- Numpy operators can only accept numpy `ndarray`s/`_Symbol`s as
inputs\n",
+ "- Classic operators can only accept classic `NDArray`s/`Symbol`s as
inputs\n",
+ "- Explicit conversions must be performed if users want to leverage
operators on both sides\n",
+ "- The layer inheriting from `HybridBlock` must have the same type of
outputs, i.e., either all classic `NDArray`s or all numpy `ndarray`s, before
hybridization\n",
+ "\n",
+ "#### Imperative"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = mx.nd.ones((2, 3)) # create a classic NDArray\n",
+ "print(a)\n",
+ "out = np.sum(a) # feeding it to a numpy operator would result in failure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = a.as_np_ndarray() # convert `a` to a numpy ndarray sharing the same
data memory\n",
+ "print(b)\n",
+ "out = np.sum(b) # feed the numpy ndarray to a numpy operator\n",
+ "print('np.sum(b) =', out)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = mx.nd.sum(b) # feeding `b` to a classic operator would result in
failure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "c = b.as_classic_ndarray() # convert `b` to a classic ndarray\n",
+ "out = mx.nd.sum(c) # feed the classic ndarray to a classic operator\n",
+ "print('mx.nd.sum(c) =', str(out))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Gluon"
]
},
{
@@ -168,19 +233,15 @@
"outputs": [],
"source": [
"from mxnet import gluon\n",
- "class TestBinaryBroadcast(gluon.HybridBlock):\n",
- " def hybrid_forward(self, F, x1, x2):\n",
- " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
- " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
- " return x1 + x2\n",
+ "class TestMultipleOutputs(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x):\n",
+ " ret1 = F.sum(x) # a classic operator produces a classic
NDArray\n",
+ " ret2 = F.np.sum(x) # a numpy operator produces a numpy
ndarray\n",
+ " return ret1, ret2\n",
"\n",
- "net = TestBinaryBroadcast()\n",
- "x1 = mx.nd.ones((2, 1))\n",
- "x2 = mx.nd.ones((1, 3))\n",
- "print('x1 input tensor type: ', str(type(x1)))\n",
- "print('x2 input tensor type: ', str(type(x2)))\n",
- "out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
- "print(out)"
+ "net = TestMultipleOutputs()\n",
+ "net.hybridize()\n",
+ "out = net(a) # `a` is a classic NDArray and will cause an error on
`F.np.sum` which is a numpy operator"
]
},
{
@@ -189,12 +250,9 @@
"metadata": {},
"outputs": [],
"source": [
- "net.hybridize() # mark the block for execution using a computational
graph\n",
- "try:\n",
- " out = net(x1, x2) # error: old symbol `+` operation does not support
broadcasting\n",
- " assert False # should not reach here\n",
- "except mx.MXNetError:\n",
- " print(\"ERROR: cannot perform broadcast add for two symbols of
mxnet.sym.Symbol\")"
+ "net = TestMultipleOutputs() # redefine a net with no pre-built graph\n",
+ "net.hybridize()\n",
+ "out = net(b) # `b` is a numpy ndarray and will cause an error on `F.sum`
which is a classic operator"
]
},
{
@@ -203,19 +261,15 @@
"metadata": {},
"outputs": [],
"source": [
- "class TestBinaryBroadcast2(gluon.HybridBlock):\n",
- " def hybrid_forward(self, F, x1, x2):\n",
- " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
- " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
- " return x1.as_np_ndarray() + x2 # convert x1 to new numpy
ndarray/symbol\n",
- "\n",
- "net2 = TestBinaryBroadcast2()\n",
- "net2.hybridize()\n",
+ "class TestMultipleOutputs2(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x): # x is known to be a numpy
ndarray\n",
+ " ret1 = F.sum(x.as_classic_ndarray()) # a classic operator
produces a classic NDArray\n",
+ " ret2 = F.np.sum() # a numpy operator produces a numpy ndarray\n",
+ " return ret1, ret2 # two outputs of the layer with different
types would result in failure in building the graph\n",
"\n",
- "print('x1 input tensor type: ', str(type(x1)))\n",
- "print('x2 input tensor type: ', str(type(x2)))\n",
- "out =net2(x1, x2)\n",
- "print(out)"
+ "net = TestMultipleOutputs2()\n",
+ "net.hybridize()\n",
+ "out = net(b)"
]
},
{
@@ -224,34 +278,45 @@
"metadata": {},
"outputs": [],
"source": [
- "net = TestBinaryBroadcast() # Create a new block object to clear the
graph\n",
- "net.hybridize() # mark the block for execution using a computational
graph\n",
+ "class TestMultipleOutputs3(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x): # x is known to be a numpy
ndarray\n",
+ " ret1 = F.sum(x.as_classic_ndarray()) # a classic operator
produces a classic NDArray\n",
+ " ret2 = F.np.sum(x) # a numpy operator produces a numpy
ndarray\n",
+ " return ret1.as_np_ndarray(), ret2 # both outputs are converted to
the same numpy type, so building the graph succeeds\n",
"\n",
- "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol
will be used in graph construction\n",
- "print('x1 input tensor type: ', str(type(x1)))\n",
- "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol
will be used in graph construction\n",
- "print('x2 input tensor type: ', str(type(x2)))\n",
- "out = net(x1, x2) # ok: `+` operation supports broadcasting for
_NumpySymbol\n",
- "print(out) # mxnet.numpy.ndarray type, because it's from a np operator"
+ "net = TestMultipleOutputs3()\n",
+ "net.hybridize()\n",
+ "out = net(b)\n",
+ "print('classic operator output: ', out[0])\n",
+ "print('numpy operator output: ', out[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## A Simple Linear Regression Model\n",
- "Let's consider a simple linear regression model as the following.\n",
- "Given dataset `{x, y}`, where `x`s represent input examples and `y`s
represent observed data, find the parameters `w1` and `w2` for the following
model.\n",
- "```\n",
- "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n",
- "```"
+ "### Binary element-wise operations with broadcasting in new and old
symbols"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "### MXNet Numpy Operators in Imperative Programming"
+ "class TestBinaryBroadcast(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x1, x2):\n",
+ " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
+ " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
+ " return x1 + x2\n",
+ "\n",
+ "net = TestBinaryBroadcast()\n",
+ "x1 = mx.nd.ones((2, 1))\n",
+ "x2 = mx.nd.ones((1, 3))\n",
+ "print('x1 input tensor type: ', str(type(x1)))\n",
+ "print('x2 input tensor type: ', str(type(x2)))\n",
+ "out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
+ "print(out)"
]
},
{
@@ -260,56 +325,41 @@
"metadata": {},
"outputs": [],
"source": [
- "import mxnet as mx\n",
- "from mxnet import numpy as np, numpy_extension as npe\n",
- "from mxnet import autograd\n",
- "\n",
- "\n",
- "# Use numpy-compatible semantics to support scalar tensors\n",
- "mx.set_np_compat(True)\n",
- "\n",
- "# N is number of examples; D_in is input dimension;\n",
- "# H is hidden dimension; D_out is output dimension.\n",
- "N, D_in, H, D_out = 64, 1000, 100, 10\n",
- "\n",
- "# Create random input and output data\n",
- "x = mx.nd.random.normal(shape=(N, D_in)).as_np_ndarray() # x is of type
mxnet.numpy.ndarray\n",
- "y = mx.nd.random.normal(shape=(N, D_out)).as_np_ndarray() # y is of type
mxnet.numpy.ndarray\n",
- "\n",
- "# Randomly initialize weights\n",
- "w1 = mx.nd.random.normal(shape=(D_in, H)).as_np_ndarray() # w1 is of
type mxnet.numpy.ndarray\n",
- "w1.attach_grad() # w1.grad is of type mxnet.numpy.ndarray\n",
- "w2 = mx.nd.random.normal(shape=(H, D_out)).as_np_ndarray() # w2 is of
type mxnet.numpy.ndarray\n",
- "w2.attach_grad() # w2.grad is of type mxnet.numpy.ndarray\n",
- "\n",
- "learning_rate = 1e-6\n",
- "\n",
- "\n",
- "for t in range(50):\n",
- " with autograd.record():\n",
- " # Forward pass: compute predicted y\n",
- " h = x.dot(w1) # equivalent to np.dot(x, w1)\n",
- " h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n",
- " y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n",
- "\n",
- " # Compute loss\n",
- " # (y_pred - y) ** 2 calls np.ndarray.__pow__\n",
- " # sum() calls np.sum() which should return a scalar tensor\n",
- " loss = ((y_pred - y) ** 2).sum()\n",
- " # Note that the print function will invoke loss.asnumpy()\n",
- " print(t, loss) # loss is a scalar tensor of type
mxnet.numpy.ndarray\n",
- " loss.backward()\n",
+ "net.hybridize() # mark the block for execution using a computational
graph\n",
+ "try:\n",
+ " out = net(x1, x2) # error: old symbol `+` operation does not support
broadcasting\n",
+ " assert False # should not reach here\n",
+ "except mx.MXNetError:\n",
+ " print(\"ERROR: cannot perform broadcast add for two symbols of type
mx.sym.Symbol\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = TestBinaryBroadcast() # redefine a net to clear the pre-built
graph cache\n",
+ "net.hybridize()\n",
"\n",
- " # Update weights\n",
- " w1 -= learning_rate * w1.grad\n",
- " w2 -= learning_rate * w2.grad"
+ "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray\n",
+ "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray\n",
+ "print('x1 input tensor type: ', str(type(x1)))\n",
+ "print('x2 input tensor type: ', str(type(x2)))\n",
+ "out = net(x1, x2) # ok: a graph is built with numpy symbols, which
support broadcasting, because the inputs are np.ndarrays\n",
+ "print(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### MXNet Numpy Operators in Gluon `HybridBlock`"
+ "## A Simple Linear Regression Model\n",
+ "Let's consider a simple linear regression model as the following.\n",
+ "Given dataset `{x, y}`, where `x`s represent input examples and `y`s
represent observed data, find the parameters `w1` and `w2` for the following
model.\n",
+ "```\n",
+ "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n",
+ "```"
]
},
{
@@ -319,13 +369,10 @@
"outputs": [],
"source": [
"import mxnet as mx\n",
- "from mxnet import gluon, autograd\n",
- "\n",
- "\n",
- "# Use numpy-compatible semantics to support scalar tensors\n",
- "mx.set_np_compat(True)\n",
+ "from mxnet import gluon, autograd, np\n",
"\n",
"\n",
+ "@np.use_np_compat\n",
"class LinearRegression(gluon.HybridBlock):\n",
" def __init__(self, num_input_dim=1000, num_hidden_dim=100,
num_output_dim=10):\n",
" super(LinearRegression, self).__init__()\n",
@@ -337,7 +384,7 @@
"\n",
" def hybrid_forward(self, F, x, w1, w2):\n",
" h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n",
- " h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n",
+ " h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating
np.ndarray\n",
" y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n",
" return y_pred\n",
"\n",
@@ -356,7 +403,9 @@
"y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type
mxnet.numpy.ndarray\n",
"\n",
"total_loss = TotalLoss()\n",
- "trainer = gluon.Trainer(regressor.collect_params(), 'sgd',
{'learning_rate': 1e-3, 'momentum': 0.9})\n",
+ "trainer = gluon.Trainer(regressor.collect_params(),\n",
+ " 'sgd',\n",
+ " {'learning_rate': 1e-3, 'momentum': 0.9,
'allow_np': True})\n",
"\n",
"for t in range(50):\n",
" with autograd.record():\n",
diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h
index 08381e2..f018c8f 100644
--- a/include/mxnet/tuple.h
+++ b/include/mxnet/tuple.h
@@ -661,6 +661,13 @@ inline bool shape_is_known(const TShape& x) {
return true;
}
+inline bool shape_is_known(const std::vector<TShape>& shapes) {
+ for (const TShape& shape : shapes) {
+ if (!shape_is_known(shape)) return false;
+ }
+ return true;
+}
+
/*! \brief helper function to cast type of container elements */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 8d570d5..070018c 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -23,7 +23,8 @@ from __future__ import absolute_import
from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
-from .base import MXNetError, is_np_compat, set_np_compat, np_compat,
use_np_compat
+from .base import MXNetError
+from .util import is_np_compat, set_np_compat, np_compat, use_np_compat
from . import base
from . import contrib
from . import ndarray
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index af4b2c5..7149d2f 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,7 +20,6 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
-from functools import wraps
import atexit
import ctypes
import os
@@ -31,7 +30,7 @@ import numpy as _np
from . import libinfo
-__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat',
'use_np_compat']
+__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
@@ -741,150 +740,6 @@ ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
-def set_np_compat(active):
- """
- Turns on/off NumPy compatibility. NumPy-compatibility is turned off by
default in backend.
-
- Parameters
- ----------
- active : bool
- Indicates whether to turn on/off NumPy compatibility.
-
- Returns
- -------
- A bool value indicating the previous state of NumPy compatibility.
- """
- prev = ctypes.c_int()
- check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active),
ctypes.byref(prev)))
- return bool(prev.value)
-
-
-def is_np_compat():
- """
- Checks whether the NumPy compatibility is currently turned on.
- NumPy-compatibility is turned off by default in backend.
-
- Returns
- -------
- A bool value indicating whether the NumPy compatibility is currently
on.
- """
- curr = ctypes.c_bool()
- check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
- return curr.value
-
-
-class _NumpyCompatibilityStateScope(object):
- """Scope for managing numpy compatibility state.
- Do not use this class directly. Use `np_compat(active)` instead.
-
- Example::
-
- with _NumpyCompatibilityStateScope(True):
- y = model(x)
- backward([y])
-
- """
- def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name
- self._enter_is_np_compat = is_np_compat
- self._prev_is_np_compat = None
-
- def __enter__(self):
- if self._enter_is_np_compat is not None:
- self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)
-
- def __exit__(self, ptype, value, trace):
- if self._enter_is_np_compat is not None and self._prev_is_np_compat !=
self._enter_is_np_compat:
- set_np_compat(self._prev_is_np_compat)
-
-
-def np_compat(active=True):
- """Returns an activated/deactivated NumPy compatibility state scope to be
used in 'with' statement
- and captures code that needs the compatibility.
-
- Example::
-
- with mx.np_compat(active=True):
- # A scalar tensor's shape is `()`, whose `ndim` is `0`.
- scalar = mx.nd.ones(shape=())
- assert scalar.shape == ()
-
- # In NumPy compatible mode, 0 in a shape means that dimension
contains zero elements.
- data = mx.sym.var("data", shape=(0, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape()
- assert arg_shapes[0] == (0, 2, 3)
- assert out_shapes[0] == (0, 2, 3)
-
- # -1 means unknown shape dimension size in the new
NumPy-compatible shape definition
- data = mx.sym.var("data", shape=(-1, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == (-1, 2, 3)
- assert out_shapes[0] == (-1, 2, 3)
-
- # When a shape is completely unknown in NumPy-compatible mode, it
is
- # represented as `None` in Python.
- data = mx.sym.var("data")
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] is None
- assert out_shapes[0] is None
-
- with mx.np_compat(active=False):
- # 0 means unknown shape dimension size in the legacy shape
definition.
- data = mx.sym.var("data", shape=(0, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == (0, 2, 3)
- assert out_shapes[0] == (0, 2, 3)
-
- # When a shape is completely unknown in the legacy mode (default),
its ndim is
- # equal to 0 and it is represented as `()` in Python.
- data = mx.sym.var("data")
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == ()
- assert out_shapes[0] == ()
- """
- return _NumpyCompatibilityStateScope(active)
-
-
-def use_np_compat(func):
- """Wraps a function with an activated NumPy-compatibility scope. This
ensures
- that the execution of the function is guaranteed with NumPy compatible
semantics,
- such as zero-dim and zero size tensors.
-
- Example::
- import mxnet as mx
- @mx.use_np_compat
- def scalar_one():
- return mx.nd.ones(())
- print(scalar_one())
-
- Parameters
- ----------
- func : a user-provided callable function to be scoped by the NumPy
compatibility state.
-
- Returns
- -------
- Function
- A function for wrapping the user functions in the NumPy compatibility
scope.
- """
- @wraps(func)
- def _with_np_compat(*args, **kwargs):
- with np_compat(active=True):
- return func(*args, **kwargs)
-
- return _with_np_compat
-
-
-def _sanity_check_params(func_name, unsupported_params, param_dict):
- for param_name in unsupported_params:
- if param_name in param_dict:
- raise NotImplementedError("function {} does not support parameter
{}"
- .format(func_name, param_name))
-
-
_NP_OP_PREFIX = '_np_'
_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
@@ -984,21 +839,3 @@ def _init_np_op_module(root_module_name, np_module_name,
mx_module_name, make_op
function.__module__ = module_name_local
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)
-
-
-def set_module(module):
- """Decorator for overriding __module__ on a function or class.
-
- Example usage::
-
- @set_module('mxnet.numpy')
- def example():
- pass
-
- assert example.__module__ == 'numpy'
- """
- def decorator(func):
- if module is not None:
- func.__module__ = module
- return func
- return decorator
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 4f5d696..e5b5f5f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -32,7 +32,8 @@ from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
-from .utils import _indent, _brief_print_list, HookHandle,
_check_same_symbol_type
+from .utils import _indent, _brief_print_list, HookHandle
+from .utils import _check_same_symbol_type, _check_all_np_ndarrays
from .. import numpy as _mx_np
@@ -542,7 +543,8 @@ class Block(object):
for hook in self._forward_hooks.values():
hook(self, args, out)
-
+ if _mx_np.is_np_compat():
+ _check_all_np_ndarrays(_flatten(out, "output")[0])
return out
def forward(self, *args):
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index f660b97..ebcb41f 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -30,7 +30,8 @@ from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, context
from ..context import Context, cpu
from .. import autograd
-from .utils import _indent, _brief_print_list
+from .utils import _indent, _brief_print_list, shape_is_known
+from ..util import is_np_compat
# pylint: disable= invalid-name
tensor_types = (symbol.Symbol, ndarray.NDArray)
@@ -130,7 +131,6 @@ class Parameter(object):
self._grad_stype = grad_stype
self._stype = stype
-
def __repr__(self):
s = 'Parameter {name} (shape={shape}, dtype={dtype})'
return s.format(name=self.name, shape=self.shape, dtype=self.dtype)
@@ -163,9 +163,9 @@ class Parameter(object):
if self._shape is None:
self._shape = new_shape
return
-
+ unknown_dim_size = -1 if is_np_compat() else 0
assert len(self._shape) == len(new_shape) and \
- all(j in (0, i) for i, j in zip(new_shape, self._shape)), \
+ all(j in (unknown_dim_size, i) for i, j in zip(new_shape,
self._shape)), \
"Expected shape %s is incompatible with given shape %s."%(
str(new_shape), str(self._shape))
@@ -269,7 +269,7 @@ class Parameter(object):
return
init, ctx, default_init, data = self._deferred_init
self._deferred_init = ()
- assert self.shape is not None and np.prod(self.shape) > 0, \
+ assert shape_is_known(self.shape), \
"Cannot initialize Parameter '%s' because it has " \
"invalid shape: %s. Please specify in_units, " \
"in_channels, etc for `Block`s."%(
@@ -281,6 +281,9 @@ class Parameter(object):
ctx=context.cpu(), stype=self._stype)
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)
+ # TODO(junwu): use np random operators when available
+ if is_np_compat():
+ data = data.as_np_ndarray() # convert to np.ndarray
self._init_impl(data, ctx)
@@ -305,6 +308,9 @@ class Parameter(object):
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype,
ctx=i.context,
stype=self._grad_stype) for i in
self._data]
+ # TODO(junwu): use np.zeros
+ if is_np_compat():
+ self._grad = [arr.as_np_ndarray() for arr in self._grad]
autograd.mark_variables(self._check_and_get(self._data, list),
self._grad, self.grad_req)
@@ -380,7 +386,7 @@ class Parameter(object):
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
- if not self.shape or np.prod(self.shape) <= 0:
+ if not shape_is_known(self.shape):
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init, None)
return
@@ -414,7 +420,6 @@ class Parameter(object):
raise ValueError("Cannot reset context for Parameter '%s' because
it "
"has not been initialized."%self.name)
-
def set_data(self, data):
"""Sets this parameter's value on all contexts."""
self.shape = data.shape
@@ -553,6 +558,8 @@ class Parameter(object):
self._var = symbol.var(self.name, shape=self.shape,
dtype=self.dtype,
lr_mult=self.lr_mult, wd_mult=self.wd_mult,
init=self.init, stype=self._stype)
+ if is_np_compat():
+ self._var = self._var.as_np_ndarray()
return self._var
def cast(self, dtype):
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index f953774..418cf41 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -38,6 +38,8 @@ except ImportError:
import numpy as np
from .. import ndarray
+from ..util import is_np_compat
+
def split_data(data, num_slice, batch_axis=0, even_split=True):
"""Splits an NDArray into `num_slice` slices along `batch_axis`.
@@ -435,3 +437,28 @@ def _check_same_symbol_type(symbols):
'computation graph, please convert all the numpy
symbols in the list '
'to classic symbols by calling
`as_classic_ndarray()` on each of them.')
return np_symbol if is_np_sym else classic_symbol
+
+
+def _check_all_np_ndarrays(out):
+ """Check if ndarrays in out are all np.ndarray"""
+ from ..numpy import ndarray as np_ndarray
+ assert isinstance(out, (list, tuple))
+ for array in out:
+ if not isinstance(array, np_ndarray):
+ raise TypeError('Expected np.ndarray type in output, while
received type '
+ '{}'.format(str(type(array))))
+
+
+def shape_is_known(shape):
+ """Check whether a shape is completely known w/ or w/o np semantics."""
+ if shape is None:
+ return False
+ unknown_dim_size = -1 if is_np_compat() else 0
+ if len(shape) == 0:
+ return unknown_dim_size == -1
+ for dim_size in shape:
+ if dim_size == unknown_dim_size:
+ return False
+ assert dim_size > unknown_dim_size, "shape dimension size cannot be
less than {}, while " \
+ "received
{}".format(unknown_dim_size, dim_size)
+ return True
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index c461d22..15bd92f 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -195,6 +195,12 @@ fixed-size items.
check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
return ndarray(handle=hdl, writable=self.writable)
+ def as_classic_ndarray(self):
+ """A convenience function for creating a classic ndarray from the
current
+ ndarray with zero copy. For this class, it just returns itself since
it is
+ already a classic ndarray."""
+ return self
+
@property
def _tvm_handle(self):
return self.handle.value
diff --git a/python/mxnet/ndarray/numpy/_op.py
b/python/mxnet/ndarray/numpy/_op.py
index e905fdf..725fba4 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -19,16 +19,15 @@
from __future__ import absolute_import
import numpy as _np
-from ...base import _sanity_check_params, use_np_compat, numeric_types,
set_module
+from ...base import numeric_types
+from ...util import _sanity_check_params, use_np_compat, set_module
from ...context import current_context
from . import _internal as _npi
-from ..ndarray import NDArray
__all__ = ['zeros', 'ones', 'maximum', 'minimum']
@set_module('mxnet.ndarray.numpy')
-@use_np_compat
def zeros(shape, dtype=_np.float32, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
@@ -60,7 +59,6 @@ def zeros(shape, dtype=_np.float32, **kwargs):
@set_module('mxnet.ndarray.numpy')
-@use_np_compat
def ones(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, filled with ones.
This function currently only supports storing multi-dimensional data
@@ -92,6 +90,7 @@ def ones(shape, dtype=None, **kwargs):
#pylint: disable= too-many-arguments, no-member, protected-access
+@use_np_compat
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None,
out=None):
""" Helper function for element-wise operation.
The function will perform numpy-like broadcasting if needed and call
different functions.
@@ -122,6 +121,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar,
lfn_scalar, rfn_scalar=None, ou
mxnet.numpy.ndarray
result array
"""
+ from ...numpy import ndarray
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
@@ -133,7 +133,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar,
lfn_scalar, rfn_scalar=None, ou
return rfn_scalar(rhs, float(lhs), out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
- elif isinstance(rhs, NDArray):
+ elif isinstance(rhs, ndarray):
return fn_array(lhs, rhs, out=out)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
@@ -141,7 +141,6 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar,
lfn_scalar, rfn_scalar=None, ou
@set_module('mxnet.ndarray.numpy')
-@use_np_compat
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
@@ -159,7 +158,6 @@ def maximum(x1, x2, out=None):
@set_module('mxnet.ndarray.numpy')
-@use_np_compat
def minimum(x1, x2, out=None):
"""Returns element-wise minimum of the input arrays with broadcasting.
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index a285e50..e93a74c 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -25,9 +25,10 @@ from ._internal import NDArrayBase, _imperative_invoke #
pylint: disable=unused-
from ..ndarray_doc import _build_doc
from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null,
_is_np_op # pylint: disable=unused-import
+from ..util import use_np_compat # pylint: disable=unused-import
-def _verify_all_np_ndarrays(op_name, func_name, *array_list):
+def _verify_all_np_ndarrays(op_name, func_name, args, out):
"""Verify if all the arrays are numpy ndarrays.
Parameters
@@ -37,11 +38,14 @@ def _verify_all_np_ndarrays(op_name, func_name, *array_list):
func_name : str
Operator name exposed to users. This is usually the name by stripping
off
the prefix of the full operator names registered in backend.
- array_list : list of arrays
+ args : list of arrays
+ Input ndarray arguments to be checked.
+ out : ndarray or None or list of ndarrays
+ User-provided output ndarrays.
"""
from ..numpy import ndarray as np_ndarray
- for array in array_list:
- if (array is not None) and (not isinstance(array, np_ndarray)):
+ for arr in args:
+ if (arr is not None) and (not isinstance(arr, np_ndarray)):
raise TypeError('Operator `{}` registered in backend is known as
`{}` in Python. '
'This is a numpy operator which can only accept '
'MXNet numpy ndarrays, while received a classic
ndarray. '
@@ -49,9 +53,22 @@ def _verify_all_np_ndarrays(op_name, func_name, *array_list):
'convert it to an MXNet numpy ndarray, and then
feed the converted '
'array to this operator.'
.format(op_name, func_name))
+ if out is None:
+ return
+ if not isinstance(out, (list, tuple)):
+ out = [out]
+ for arr in out:
+ if (arr is not None) and (not isinstance(arr, np_ndarray)):
+ raise TypeError('Operator `{}` registered in backend is known as
`{}` in Python. '
+ 'This is a numpy operator which can only write to
MXNet numpy '
+ 'ndarrays, while received a classic ndarray. '
+ 'Please call `as_np_ndarray()` upon the classic
ndarray to '
+ 'convert it to an MXNet numpy ndarray, and then
feed the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
-def _verify_all_classic_ndarrays(op_name, func_name, *array_list):
+def _verify_all_classic_ndarrays(op_name, func_name, args, out):
"""Verify if all the arrays are classic ndarrays.
Parameters
@@ -61,11 +78,14 @@ def _verify_all_classic_ndarrays(op_name, func_name, *array_list):
func_name : str
Operator name exposed to users. This is usually the name by stripping
off
the prefix of the full operator names registered in backend.
- array_list : list of arrays
+ args : list of arrays
+ Input ndarray arguments to be checked.
+ out : ndarray or None or list of ndarrays
+ User-provided output ndarrays.
"""
from ..numpy import ndarray as np_ndarray
- for array in array_list:
- if (array is not None) and (isinstance(array, np_ndarray)):
+ for arr in args:
+ if (arr is not None) and (isinstance(arr, np_ndarray)):
raise TypeError('Operator `{}` registered in backend is known as
`{}` in Python. '
'This is a classic operator which can only accept '
'classic ndarrays, while received an MXNet numpy
ndarray. '
@@ -73,6 +93,19 @@ def _verify_all_classic_ndarrays(op_name, func_name, *array_list):
'convert it to a classic ndarray, and then feed
the converted '
'array to this operator.'
.format(op_name, func_name))
+ if out is None:
+ return
+ if not isinstance(out, (list, tuple)):
+ out = [out]
+ for arr in out:
+ if (arr is not None) and (isinstance(arr, np_ndarray)):
+ raise TypeError('Operator `{}` registered in backend is known as
`{}` in Python. '
+ 'This is a classic operator which can only write
to '
+ 'classic ndarrays, while received an MXNet numpy
ndarray. '
+ 'Please call `as_classic_ndarray()` upon the numpy
ndarray to '
+ 'convert it to a classic ndarray, and then feed
the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
# pylint: disable=too-many-locals
@@ -138,6 +171,12 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F
signature = ndsignature + signature
code = []
+ is_np_op = _is_np_op(op_name)
+ doc_str_idx = 1
+ if is_np_op:
+ doc_str_idx = 2
+ code.append("""
+@use_np_compat""")
if arr_name:
code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
@@ -187,13 +226,12 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
keys.append('%s')
vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name,
dtype_name))
- is_np_op = _is_np_op(op_name)
verify_ndarrays_fn =\
_verify_all_np_ndarrays.__name__ if is_np_op else
_verify_all_classic_ndarrays.__name__
if not signature_only:
code.append("""
- {}("{}", "{}", out, *ndargs)
- """.format(verify_ndarrays_fn, op_name, func_name))
+ {verify_fn}("{op_name}", "{func_name}", ndargs, out)
+ """.format(verify_fn=verify_ndarrays_fn, op_name=op_name,
func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%(
handle.value, str(is_np_op)))
@@ -204,7 +242,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
doc_str_lines = _os.linesep+''.join([' '+s if s.strip() else s
for s in
'r"""{doc_str}"""'.format(doc_str=doc_str)
.splitlines(True)])
- code.insert(1, doc_str_lines)
+ code.insert(doc_str_idx, doc_str_lines)
return ''.join(code), doc_str
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 0f3c3c7..6d6ac6a 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -26,6 +26,6 @@ from .multiarray import * # pylint: disable=wildcard-import
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
-from ..base import use_np_compat, set_np_compat, np_compat
+from ..util import use_np_compat, set_np_compat, np_compat, is_np_compat
__all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index dfcce0b..f5a3b83 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -28,8 +28,9 @@ import numpy as _np
from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
from ..ndarray._internal import _set_np_ndarray_class
from . import _op as _mx_np_op
-from ..base import use_np_compat, check_call, _LIB, NDArrayHandle,
_sanity_check_params
-from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, set_module
+from ..base import check_call, _LIB, NDArrayHandle
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
+from ..util import _sanity_check_params, set_module, use_np_compat
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
@@ -74,6 +75,7 @@ _set_np_ndarray_class(_np_ndarray_cls)
@set_module('mxnet.numpy') # pylint: disable=invalid-name
+@use_np_compat
class ndarray(NDArray):
"""An array object represents a multidimensional, homogeneous array of
fixed-size items.
An associated data-type object describes the format of each element in the
array
@@ -81,16 +83,24 @@ class ndarray(NDArray):
floating point number, or something else, etc.). Arrays should be
constructed using
`array`, `zeros` or `empty`. Currently, only c-contiguous arrays are
supported."""
- @use_np_compat
def __getitem__(self, item):
# TODO(junwu): make output shape of integer indexing correct
raise NotImplementedError
- @use_np_compat
def __setitem__(self, key, value):
- self.as_classic_ndarray().__setitem__(key, value)
+ if self.size == 0:
+ return
+ if self.ndim == 0:
+ if key != ():
+ raise IndexError('scalar tensor can only accept `()` as index')
+ # TODO(junwu): Better handling of this situation
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXShallowCopyNDArray(self.handle,
ctypes.byref(hdl)))
+ classic_ndarray = NDArray(handle=hdl, writable=self.writable)
+ classic_ndarray.__setitem__(slice(None), value)
+ return
+ self._as_classic_ndarray().__setitem__(key, value)
- @use_np_compat
def __add__(self, other):
"""x.__add__(y) <=> x + y"""
if isinstance(other, ndarray):
@@ -100,7 +110,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __iadd__(self, other):
"""x.__iadd__(y) <=> x += y"""
if not self.writable:
@@ -112,7 +121,6 @@ class ndarray(NDArray):
else:
raise TypeError('type {} is not
supported'.format(str(type(other))))
- @use_np_compat
def __sub__(self, other):
"""x.__sub__(y) <=> x - y"""
if isinstance(other, ndarray):
@@ -122,7 +130,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __isub__(self, other):
"""x.__isub__(y) <=> x -= y"""
if not self.writable:
@@ -134,7 +141,6 @@ class ndarray(NDArray):
else:
raise TypeError('type {} is not
supported'.format(str(type(other))))
- @use_np_compat
def __rsub__(self, other):
"""x.__rsub__(y) <=> y - x"""
if isinstance(other, ndarray):
@@ -144,7 +150,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __mul__(self, other):
"""x.__mul__(y) <=> x * y"""
if isinstance(other, ndarray):
@@ -154,15 +159,12 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __neg__(self):
return self.__mul__(-1.0)
- @use_np_compat
def __imul__(self, other):
raise NotImplementedError
- @use_np_compat
def __rmul__(self, other):
"""x.__rmul__(y) <=> y * x"""
return self.__mul__(other)
@@ -181,11 +183,9 @@ class ndarray(NDArray):
' module. If you are using Python3, this error
should not have'
' been encountered.')
- @use_np_compat
def __idiv__(self, other):
raise NotImplementedError
- @use_np_compat
def __truediv__(self, other):
"""x.__truediv__(y) <=> x / y"""
if isinstance(other, ndarray):
@@ -195,7 +195,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
divisor".format(str(type(other))))
- @use_np_compat
def __rtruediv__(self, other):
"""x.__rtruediv__(y) <=> y / x"""
if isinstance(other, ndarray):
@@ -205,11 +204,9 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
dividend".format(str(type(other))))
- @use_np_compat
def __itruediv__(self, other):
raise NotImplementedError
- @use_np_compat
def __mod__(self, other):
"""x.__mod__(y) <=> x % y"""
if isinstance(other, ndarray):
@@ -219,7 +216,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __rmod__(self, other):
"""x.__rmod__(y) <=> y % x"""
if isinstance(other, ndarray):
@@ -229,11 +225,9 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __imod__(self, other):
raise NotImplementedError
- @use_np_compat
def __pow__(self, other):
"""x.__pow__(y) <=> x ** y"""
if isinstance(other, ndarray):
@@ -243,7 +237,6 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
if isinstance(other, ndarray):
@@ -253,45 +246,36 @@ class ndarray(NDArray):
else:
raise TypeError("ndarray does not support type {} as
operand".format(str(type(other))))
- @use_np_compat
def __eq__(self, other):
"""x.__eq__(y) <=> x == y"""
raise NotImplementedError
- @use_np_compat
def __hash__(self):
raise NotImplementedError
- @use_np_compat
def __ne__(self, other):
"""x.__ne__(y) <=> x != y"""
raise NotImplementedError
- @use_np_compat
def __gt__(self, other):
"""x.__gt__(y) <=> x > y"""
raise NotImplementedError
- @use_np_compat
def __ge__(self, other):
"""x.__ge__(y) <=> x >= y"""
raise NotImplementedError
- @use_np_compat
def __lt__(self, other):
"""x.__lt__(y) <=> x < y"""
raise NotImplementedError
- @use_np_compat
def __le__(self, other):
"""x.__le__(y) <=> x <= y"""
raise NotImplementedError
- @use_np_compat
def __bool__(self):
raise NotImplementedError
- @use_np_compat
def __len__(self):
"""Number of elements along the first axis."""
return self.shape[0]
@@ -329,29 +313,38 @@ class ndarray(NDArray):
return self.transpose()
# pylint: enable= invalid-name, undefined-variable
- @use_np_compat
def _slice(self, start, stop):
raise NotImplementedError
- @use_np_compat
def _at(self, idx):
raise NotImplementedError
- @use_np_compat
def all(self, axis=None, out=None, keepdims=False):
raise NotImplementedError
- @use_np_compat
def any(self, axis=None, out=None, keepdims=False):
raise NotImplementedError
- def as_classic_ndarray(self):
- """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its
fluent methods."""
+ def _as_classic_ndarray(self):
+ """This is not a user-facing API."""
hdl = NDArrayHandle()
check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
return NDArray(handle=hdl, writable=self.writable)
- @use_np_compat
+ def as_classic_ndarray(self):
+ """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its
fluent methods."""
+ if self.ndim == 0: # TODO(junwu): this costs ~10ns, can be moved to
backend
+ raise ValueError('cannot convert a scalar np.ndarray to
mx.nd.NDArray')
+ if self.size == 0: # TODO(junwu): this costs ~10ns, can be moved to
backend
+ raise ValueError('cannot convert a zero-size np.ndarray to
mx.nd.NDArray')
+ return self._as_classic_ndarray()
+
+ def as_np_ndarray(self):
+ """A convenience function for creating a numpy ndarray from the
current ndarray
+ with zero copy. For this class, it just returns itself since it's
already a
+ numpy ndarray."""
+ return self
+
def __repr__(self):
"""Returns a string representation of the array using the following
rules:
1. If the `ndarray` is a scalar tensor, only the string of the scalar
is returned.
@@ -369,7 +362,6 @@ class ndarray(NDArray):
else:
return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__,
self.shape)
- @use_np_compat
def attach_grad(self, grad_req='write'): # pylint:
disable=arguments-differ
"""Attach a gradient buffer to this ndarray, so that `backward`
can compute gradient with respect to it.
@@ -398,14 +390,12 @@ class ndarray(NDArray):
return None
return _np_ndarray_cls(hdl)
- @use_np_compat
def detach(self):
"""Returns a new ndarray, detached from the current graph."""
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
return _np_ndarray_cls(hdl)
- @use_np_compat
def astype(self, dtype, *args, **kwargs): # pylint:
disable=arguments-differ,unused-argument
"""
Copy of the array, cast to a specified type.
@@ -436,7 +426,6 @@ class ndarray(NDArray):
self.copyto(res)
return res
- @use_np_compat
def copyto(self, other):
"""Copies the value of this array to another array.
@@ -470,8 +459,8 @@ class ndarray(NDArray):
[ 1., 1., 1.]], dtype=float32)
"""
if isinstance(other, ndarray):
- other = other.as_classic_ndarray()
- return self.as_classic_ndarray().copyto(other).as_np_ndarray()
+ other = other._as_classic_ndarray()
+ return self._as_classic_ndarray().copyto(other).as_np_ndarray()
def asscalar(self):
raise AttributeError('mxnet.numpy.ndarray object has no attribute
as_scalar')
@@ -479,18 +468,15 @@ class ndarray(NDArray):
def as_in_context(self, context):
return super(ndarray, self).as_in_context(context).as_np_ndarray()
- @use_np_compat
def copy(self, order='C'): # pylint: disable=arguments-differ
if order != 'C':
raise NotImplementedError('ndarray.copy only supports order=\'C\',
while '
'received {}'.format(str(order)))
return super(ndarray, self).copy().as_np_ndarray()
- @use_np_compat
def dot(self, b, out=None):
return _mx_np_op.dot(self, b, out=out)
- @use_np_compat
def reshape(self, shape, order='C'): # pylint: disable=arguments-differ
"""Returns an array containing the same data with a new shape."""
if order != 'C':
@@ -530,7 +516,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
broadcast_like')
- @use_np_compat
def repeat(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`repeat`.
@@ -547,7 +532,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute pad')
- @use_np_compat
def swapaxes(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`swapaxes`.
@@ -596,7 +580,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
slice_like')
- @use_np_compat
def take(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`take`.
@@ -621,7 +604,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
pick')
- @use_np_compat
def sort(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sort`.
@@ -638,7 +620,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
topk')
- @use_np_compat
def argsort(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argsort`.
@@ -647,7 +628,6 @@ class ndarray(NDArray):
"""
raise NotImplementedError
- @use_np_compat
def argmax(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax`.
@@ -664,7 +644,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
argmax_channel')
- @use_np_compat
def argmin(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmin`.
@@ -673,7 +652,6 @@ class ndarray(NDArray):
"""
raise NotImplementedError
- @use_np_compat
def clip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`clip`.
@@ -698,7 +676,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute abs')
- @use_np_compat
def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
@@ -739,7 +716,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
tile')
- @use_np_compat
def transpose(self, *axes): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`transpose`.
@@ -780,7 +756,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
diag')
- @use_np_compat
def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint:
disable=arguments-differ
"""Convenience fluent method for :py:func:`sum`.
@@ -797,7 +772,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
nansum')
- @use_np_compat
def prod(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`prod`.
@@ -814,7 +788,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
nanprod')
- @use_np_compat
def mean(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`mean`.
@@ -823,7 +796,6 @@ class ndarray(NDArray):
"""
raise NotImplementedError
- @use_np_compat
def max(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`max`.
@@ -832,7 +804,6 @@ class ndarray(NDArray):
"""
raise NotImplementedError
- @use_np_compat
def min(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`min`.
@@ -849,7 +820,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
norm')
- @use_np_compat
def round(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`round`.
@@ -1146,7 +1116,6 @@ class ndarray(NDArray):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute
softmin')
- @use_np_compat
def squeeze(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`squeeze`.
@@ -1162,12 +1131,10 @@ class ndarray(NDArray):
raise AttributeError('mxnet.numpy.ndarray object has no attribute
broadcast_like')
@property
- @use_np_compat
def shape(self):
return super(ndarray, self).shape
@property
- @use_np_compat
def ndim(self):
"""Number of array dimensions."""
return len(self.shape)
@@ -1249,7 +1216,10 @@ def array(object, dtype=None, **kwargs):
except:
raise TypeError('source array must be an array like object')
ret = empty(object.shape, dtype=dtype, ctx=ctx)
- ret[:] = object
+ if len(object.shape) == 0:
+ ret[()] = object
+ else:
+ ret[:] = object
return ret
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index 613ae89..5878be1 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -18,6 +18,7 @@
# pylint: disable=too-many-lines
"""Weight updating functions."""
+from __future__ import absolute_import
import logging
import math
import pickle
@@ -94,7 +95,7 @@ class Optimizer(object):
def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
clip_gradient=None, learning_rate=0.01,
lr_scheduler=None, sym=None, begin_num_update=0,
- multi_precision=False, param_dict=None):
+ multi_precision=False, param_dict=None, allow_np=False):
self.rescale_grad = rescale_grad
self.lr = learning_rate
self.lr_scheduler = lr_scheduler
@@ -119,6 +120,7 @@ class Optimizer(object):
self.idx2name = param_idx2name.copy()
self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not
None else ()
self.param_dict = param_dict if param_dict else {}
+ self.allow_np = allow_np
self.set_lr_mult({})
self.set_wd_mult({})
@@ -1618,6 +1620,25 @@ class Test(Optimizer):
# backward compatibility wrapper for Optimizer.CreateOptimizer
create = Optimizer.create_optimizer # pylint: disable=invalid-name
+
+def _as_classic(a, allow_np):
+ from ..numpy import ndarray as np_ndarray
+ if isinstance(a, (tuple, list)):
+ if any(isinstance(x, np_ndarray) for x in a):
+ if allow_np:
+ return [x.as_classic_ndarray() for x in a]
+ else:
+ raise ValueError('Converting np.ndarray to mx.nd.NDArray is
not allowed')
+ else:
+ if isinstance(a, np_ndarray):
+ if allow_np:
+ return a.as_classic_ndarray()
+ else:
+ raise ValueError('Converting np.ndarray to mx.nd.NDArray is
not allowed')
+ return a
+
+
+
class Updater(object):
"""Updater for kvstore."""
def __init__(self, optimizer):
@@ -1628,14 +1649,15 @@ class Updater(object):
def __call__(self, index, grad, weight):
"""Updates weight given gradient and index."""
+ allow_np = self.optimizer.allow_np
if not isinstance(index, (list, tuple)):
indices = [index]
- grads = [grad]
- weights = [weight]
+ grads = [_as_classic(grad, allow_np)]
+ weights = [_as_classic(weight, allow_np)]
else:
indices = index
- grads = grad
- weights = weight
+ grads = _as_classic(grad, allow_np)
+ weights = _as_classic(weight, allow_np)
if weights:
self.optimizer._set_current_context(weights[0].context.device_id)
for i, idx in enumerate(indices):
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 0bbd96b..6a03cdb 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -22,8 +22,8 @@ from __future__ import absolute_import
import ctypes
import numpy as _np
from . import _op as _mx_np_op
-from ...base import _sanity_check_params, use_np_compat, check_call, _LIB,
SymbolHandle
-from ...base import numeric_types, set_module
+from ...base import _LIB, SymbolHandle, numeric_types
+from ...util import _sanity_check_params, check_call, set_module
from ...context import current_context
from ..symbol import Symbol
from .._internal import _set_np_symbol_class
@@ -43,7 +43,6 @@ class _Symbol(Symbol):
def __iter__(self):
raise AttributeError('_Symbol object has no attribute __iter__')
- @use_np_compat
def __add__(self, other):
"""x.__add__(y) <=> x + y"""
if isinstance(other, _Symbol):
@@ -54,7 +53,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __sub__(self, other):
"""x.__sub__(y) <=> x - y"""
if isinstance(other, _Symbol):
@@ -65,7 +63,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __rsub__(self, other):
"""x.__rsub__(y) <=> y - x"""
if isinstance(other, _Symbol):
@@ -76,7 +73,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __mul__(self, other):
"""x.__mul__(y) <=> x * y"""
if isinstance(other, _Symbol):
@@ -87,7 +83,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __rmul__(self, other):
"""x.__rmul__(y) <=> y * x"""
if isinstance(other, _Symbol):
@@ -112,7 +107,6 @@ class _Symbol(Symbol):
' module. If you are using Python3, this error
should not have'
' been encountered.')
- @use_np_compat
def __mod__(self, other):
"""x.__mod__(y) <=> x % y"""
if isinstance(other, _Symbol):
@@ -123,7 +117,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __rmod__(self, other):
"""x.__rmod__(y) <=> y % x"""
if isinstance(other, _Symbol):
@@ -134,11 +127,9 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __idiv__(self, other):
raise NotImplementedError
- @use_np_compat
def __truediv__(self, other):
"""x.__truediv__(y) <=> x / y"""
if isinstance(other, _Symbol):
@@ -149,7 +140,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as divisor"
.format(str(type(other))))
- @use_np_compat
def __rtruediv__(self, other):
"""x.__rtruediv__(y) <=> y / x"""
if isinstance(other, _Symbol):
@@ -160,11 +150,9 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as dividend"
.format(str(type(other))))
- @use_np_compat
def __itruediv__(self, other):
raise NotImplementedError
- @use_np_compat
def __pow__(self, other):
"""x.__pow__(y) <=> x ** y"""
if isinstance(other, _Symbol):
@@ -175,7 +163,6 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
if isinstance(other, _Symbol):
@@ -186,41 +173,33 @@ class _Symbol(Symbol):
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
- @use_np_compat
def __neg__(self):
"""x.__neg__() <=> - x"""
return self.__mul__(-1.0)
- @use_np_compat
def __deepcopy__(self, _):
return super(_Symbol, self).as_np_ndarray()
- @use_np_compat
def __eq__(self, other):
"""x.__eq__(y) <=> x == y"""
raise NotImplementedError
- @use_np_compat
def __ne__(self, other):
"""x.__ne__(y) <=> x != y"""
raise NotImplementedError
- @use_np_compat
def __gt__(self, other):
"""x.__gt__(y) <=> x > y"""
raise NotImplementedError
- @use_np_compat
def __ge__(self, other):
"""x.__ge__(y) <=> x >= y"""
raise NotImplementedError
- @use_np_compat
def __lt__(self, other):
"""x.__lt__(y) <=> x < y"""
raise NotImplementedError
- @use_np_compat
def __le__(self, other):
"""x.__le__(y) <=> x <= y"""
raise NotImplementedError
@@ -241,15 +220,12 @@ class _Symbol(Symbol):
return self.transpose()
# pylint: enable= invalid-name, undefined-variable
- @use_np_compat
def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ
raise NotImplementedError
- @use_np_compat
def dot(self, b, out=None):
return _mx_np_op.dot(self, b, out=out)
- @use_np_compat
def reshape(self, shape, order='C'): # pylint: disable=arguments-differ
if order != 'C':
raise NotImplementedError('ndarray.copy only supports order=\'C\',
while '
@@ -288,7 +264,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute broadcast_like')
- @use_np_compat
def repeat(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`repeat`.
@@ -305,7 +280,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute pad')
- @use_np_compat
def swapaxes(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`swapaxes`.
@@ -354,7 +328,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute slice_like')
- @use_np_compat
def take(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`take`.
@@ -379,7 +352,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute pick')
- @use_np_compat
def sort(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sort`.
@@ -396,7 +368,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute topk')
- @use_np_compat
def argsort(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argsort`.
@@ -405,7 +376,6 @@ class _Symbol(Symbol):
"""
raise NotImplementedError
- @use_np_compat
def argmax(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax`.
@@ -422,7 +392,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute argmax_channel')
- @use_np_compat
def argmin(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmin`.
@@ -431,7 +400,6 @@ class _Symbol(Symbol):
"""
raise NotImplementedError
- @use_np_compat
def clip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`clip`.
@@ -456,7 +424,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute abs')
- @use_np_compat
def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
@@ -497,7 +464,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute tile')
- @use_np_compat
def transpose(self, *axes): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`transpose`.
@@ -538,7 +504,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute diag')
- @use_np_compat
def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint:
disable=arguments-differ
"""Convenience fluent method for :py:func:`sum`.
@@ -555,7 +520,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute nansum')
- @use_np_compat
def prod(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`prod`.
@@ -572,7 +536,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute nanprod')
- @use_np_compat
def mean(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`mean`.
@@ -581,7 +544,6 @@ class _Symbol(Symbol):
"""
raise NotImplementedError
- @use_np_compat
def max(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`max`.
@@ -590,7 +552,6 @@ class _Symbol(Symbol):
"""
raise NotImplementedError
- @use_np_compat
def min(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`min`.
@@ -607,7 +568,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute norm')
- @use_np_compat
def round(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`round`.
@@ -904,7 +864,6 @@ class _Symbol(Symbol):
"""
raise AttributeError('_Symbol object has no attribute softmin')
- @use_np_compat
def squeeze(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`squeeze`.
@@ -921,7 +880,6 @@ class _Symbol(Symbol):
@set_module('mxnet.symbol.numpy')
-@use_np_compat
def zeros(shape, dtype=_np.float32, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
@@ -953,7 +911,6 @@ def zeros(shape, dtype=_np.float32, **kwargs):
@set_module('mxnet.symbol.numpy')
-@use_np_compat
def ones(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
@@ -1034,13 +991,11 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
@set_module('mxnet.symbol.numpy')
-@use_np_compat
def maximum(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum,
_npi.maximum_scalar, None, out)
@set_module('mxnet.symbol.numpy')
-@use_np_compat
def minimum(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum,
_npi.minimum_scalar, None, out)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d84a1cb..6048fc7 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -34,7 +34,7 @@ import numpy as _numpy # pylint: disable=relative-import
from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str,
c_str_array, c_handle_array
-from ..base import mx_uint, py_str, string_types, integer_types, mx_int,
is_np_compat
+from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
@@ -45,6 +45,7 @@ from ..executor import Executor
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
+from ..util import is_np_compat
__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
"ones", "full", "arange",
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index fc8d985..60b478b 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -19,6 +19,9 @@
import ctypes
import os
import sys
+import functools
+import itertools
+import inspect
from .base import _LIB, check_call
@@ -44,3 +47,221 @@ def get_gpu_memory(gpu_dev_id):
total_mem = ctypes.c_uint64(0)
check_call(_LIB.MXGetGPUMemoryInformation64(gpu_dev_id,
ctypes.byref(free_mem), ctypes.byref(total_mem)))
return free_mem.value, total_mem.value
+
+
+def set_np_compat(active):
+ """
+ Turns on/off NumPy compatibility. NumPy-compatibility is turned off by
default in backend.
+
+ Parameters
+ ----------
+ active : bool
+ Indicates whether to turn on/off NumPy compatibility.
+
+ Returns
+ -------
+ A bool value indicating the previous state of NumPy compatibility.
+ """
+ prev = ctypes.c_int()
+ check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active),
ctypes.byref(prev)))
+ return bool(prev.value)
+
+
+def is_np_compat():
+ """
+ Checks whether the NumPy compatibility is currently turned on.
+ NumPy-compatibility is turned off by default in backend.
+
+ Returns
+ -------
+ A bool value indicating whether the NumPy compatibility is currently
on.
+ """
+ curr = ctypes.c_bool()
+ check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
+ return curr.value
+
+
+class _NumpyCompatibilityStateScope(object):
+ """Scope for managing numpy compatibility state.
+ Do not use this class directly. Use `np_compat(active)` instead.
+
+ Example::
+
+ with _NumpyCompatibilityStateScope(True):
+ y = model(x)
+ backward([y])
+
+ """
+ def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name
+ self._enter_is_np_compat = is_np_compat
+ self._prev_is_np_compat = None
+
+ def __enter__(self):
+ if self._enter_is_np_compat is not None:
+ self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)
+
+ def __exit__(self, ptype, value, trace):
+ if self._enter_is_np_compat is not None and self._prev_is_np_compat !=
self._enter_is_np_compat:
+ set_np_compat(self._prev_is_np_compat)
+
+
+def np_compat(active=True):
+ """Returns an activated/deactivated NumPy compatibility state scope to be used in a 'with'
+ statement that wraps the code needing that compatibility state.
+
+ Example::
+
+ with mx.np_compat(active=True):
+ # A scalar tensor's shape is `()`, whose `ndim` is `0`.
+ scalar = mx.nd.ones(shape=())
+ assert scalar.shape == ()
+
+ # In NumPy compatible mode, 0 in a shape means that dimension
contains zero elements.
+ data = mx.sym.var("data", shape=(0, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape()
+ assert arg_shapes[0] == (0, 2, 3)
+ assert out_shapes[0] == (0, 2, 3)
+
+ # -1 means unknown shape dimension size in the new
NumPy-compatible shape definition
+ data = mx.sym.var("data", shape=(-1, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == (-1, 2, 3)
+ assert out_shapes[0] == (-1, 2, 3)
+
+ # When a shape is completely unknown in NumPy-compatible mode, it
is
+ # represented as `None` in Python.
+ data = mx.sym.var("data")
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] is None
+ assert out_shapes[0] is None
+
+ with mx.np_compat(active=False):
+ # 0 means unknown shape dimension size in the legacy shape
definition.
+ data = mx.sym.var("data", shape=(0, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == (0, 2, 3)
+ assert out_shapes[0] == (0, 2, 3)
+
+ # When a shape is completely unknown in the legacy mode (default),
its ndim is
+ # equal to 0 and it is represented as `()` in Python.
+ data = mx.sym.var("data")
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == ()
+ assert out_shapes[0] == ()
+ """
+ return _NumpyCompatibilityStateScope(active)
+
+
+def wraps_safely(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS):
+ """This function is a safe version of `functools.wraps` for Python2 which skips copying
+ the attributes in `assigned` that do not exist on `wrapped`."""
+ if sys.version_info[0] > 2:
+ return functools.wraps(wrapped)
+ else:
+ return functools.wraps(wrapped,
+ assigned=itertools.ifilter(
+ functools.partial(hasattr, wrapped),
assigned))
+
+
+def use_np_compat(func):
+ """A decorator wrapping a function or class with an activated NumPy-compatibility scope.
+ When `func` is a function, this ensures that the execution of the function is scoped with
+ NumPy-compatible semantics, such as the support for zero-dim and zero-size tensors. When
+ `func` is a class, it ensures that all the methods, static functions, and properties
+ of the class are executed with the NumPy-compatible semantics.
+
+ Example::
+ import mxnet as mx
+ @mx.use_np_compat
+ def scalar_one():
+ return mx.nd.ones(())
+ print(scalar_one())
+
+ @np.use_np_compat
+ class ScalarTensor(object):
+ def __init__(self, val=None):
+ if val is None:
+ val = ScalarTensor.random().value
+ self._scalar = mx.nd.ones(()) * val
+
+ def __repr__(self):
+ print("Is __repr__ numpy compatible?
{}!".format(str(np.is_np_compat())))
+ return str(self._scalar.asnumpy())
+
+ @staticmethod
+ def random():
+ val = mx.nd.random.uniform().asnumpy().item()
+ return ScalarTensor(val)
+
+ @property
+ def value(self):
+ print("Is value property numpy compatible?
{}!".format(str(np.is_np_compat())))
+ return self._scalar.asnumpy().item()
+
+
+ print("Is global scope numpy compatible?
{}!".format(str(np.is_np_compat())))
+ scalar_tensor = ScalarTensor()
+ print(scalar_tensor)
+
+ Parameters
+ ----------
+ func : a user-provided callable function or class to be scoped by the
NumPy compatibility state.
+
+ Returns
+ -------
+ Function or class
+ A function or class wrapped in the NumPy compatibility scope.
+ """
+
+ if inspect.isclass(func):
+ for name, method in inspect.getmembers(
+ func,
+ predicate=
+ lambda f: inspect.isfunction(f) or inspect.ismethod(f) or
isinstance(f, property)):
+ if isinstance(method, property):
+ setattr(func, name, property(use_np_compat(method.__get__),
+ method.__set__,
+ method.__delattr__,
+ method.__doc__))
+ else:
+ setattr(func, name, use_np_compat(method))
+ return func
+ elif callable(func):
+ @wraps_safely(func)
+ def _with_np_compat(*args, **kwargs):
+ with np_compat(active=True):
+ return func(*args, **kwargs)
+ return _with_np_compat
+ else:
+ raise TypeError('use_np_compat can only decorate classes and callable
objects, '
+ 'while received a {}'.format(str(type(func))))
+
+
+def _sanity_check_params(func_name, unsupported_params, param_dict):
+ for param_name in unsupported_params:
+ if param_name in param_dict:
+ raise NotImplementedError("function {} does not support parameter
{}"
+ .format(func_name, param_name))
+
+
+def set_module(module):
+ """Decorator for overriding __module__ on a function or class.
+
+ Example usage::
+
+ @set_module('mxnet.numpy')
+ def example():
+ pass
+
+ assert example.__module__ == 'mxnet.numpy'
+ """
+ def decorator(func):
+ if module is not None:
+ func.__module__ = module
+ return func
+ return decorator
diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc
index bcb310f..992bef0 100644
--- a/src/operator/numpy/np_dot.cc
+++ b/src/operator/numpy/np_dot.cc
@@ -36,29 +36,43 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& a_shape = in_attrs->at(0);
const mxnet::TShape& b_shape = in_attrs->at(1);
- if (!shape_is_known(a_shape) || !shape_is_known(b_shape)) {
+ if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) {
return false;
}
if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
// Case 1: both 1-D arrays, inner product of vectors
- CHECK_EQ(a_shape[0], b_shape[0]);
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1));
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0));
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, 0));
} else if (a_shape.ndim() == 2 && b_shape.ndim() == 2) {
// Case 2: both 2-D arrays, matrix multiplication
- CHECK_EQ(a_shape[1], b_shape[0]);
- mxnet::TShape mm_shape(2, 0);
- mm_shape[0] = a_shape[0];
- mm_shape[1] = b_shape[1];
- SHAPE_ASSIGN_CHECK(*out_attrs, 0, mm_shape);
+ mxnet::TShape tmp_shape(2, -1);
+ tmp_shape[1] = b_shape[0];
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape);
+
+ tmp_shape[0] = a_shape[1];
+ tmp_shape[1] = -1;
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape);
+
+ tmp_shape[0] = a_shape[0];
+ tmp_shape[1] = b_shape[1];
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape);
} else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) {
// Case 3 + 3.5: either of them is a scalar, just scale by one of them
mxnet::TShape oshape = (a_shape.ndim() == 0) ? b_shape : a_shape;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
} else if (b_shape.ndim() == 1) {
// Case 4: a is N-D array and b is 1-D array, sum product over the last
axis
- CHECK_EQ(a_shape[a_shape.ndim() - 1], b_shape[0]);
- mxnet::TShape out_shape(a_shape.ndim() - 1, 0);
+ TShape tmp_shape(a_shape.ndim(), -1);
+ tmp_shape[a_shape.ndim() - 1] = b_shape[0];
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape);
+
+ tmp_shape = TShape(1, -1);
+ tmp_shape[0] = a_shape[a_shape.ndim() - 1];
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape);
+
+ mxnet::TShape out_shape(a_shape.ndim() - 1, -1);
for (int i = 0; i < a_shape.ndim() - 1; ++i) {
out_shape[i] = a_shape[i];
}
@@ -68,7 +82,7 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs,
// of a and the 2nd-to-last axis of b
LOG(FATAL) << "Case 5 not implemented yet...";
}
- return true;
+ return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
}
NNVM_REGISTER_OP(_np_dot)
diff --git a/tests/python/gpu/test_operator_gpu.py
b/tests/python/gpu/test_operator_gpu.py
index 055acfd..5bde086 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -37,6 +37,7 @@ from common import run_in_spawned_process
from test_operator import *
from test_numpy_op import *
from test_numpy_ndarray import *
+from test_numpy_gluon import *
from test_optimizer import *
from test_random import *
from test_exc_handling import *
diff --git a/tests/python/unittest/test_numpy_gluon.py
b/tests/python/unittest/test_numpy_gluon.py
new file mode 100644
index 0000000..446f5b8
--- /dev/null
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import absolute_import
+from __future__ import division
+import mxnet as mx
+from mxnet import gluon, autograd, np
+
+
+def test_create_np_param():
+ M, K, N = 10, 9, 20
+
+ def check_block_params(x, TestBlock, hybridize, expected_type):
+ net = TestBlock()
+ net.initialize()
+ if hybridize:
+ net.hybridize()
+ net(x)
+ params = net.collect_params()
+ for k, v in params.items():
+ assert type(v.data()) is expected_type
+
+ class TestBlock1(gluon.HybridBlock):
+ def __init__(self):
+ super(TestBlock1, self).__init__()
+ with self.name_scope():
+ self.w = self.params.get('w', shape=(K, N),
allow_deferred_init=True)
+
+ def hybrid_forward(self, F, x, w):
+ return F.dot(x, w)
+
+ @np.use_np_compat
+ class TestBlock2(gluon.HybridBlock):
+ def __init__(self):
+ super(TestBlock2, self).__init__()
+ with self.name_scope():
+ self.w = self.params.get('w', shape=(K, N),
allow_deferred_init=True)
+
+ def hybrid_forward(self, F, x, w):
+ return F.np.dot(x, w)
+
+ x = mx.nd.random.uniform(shape=(M, K))
+ check_block_params(x, TestBlock1, False, mx.nd.NDArray)
+ check_block_params(x, TestBlock1, True, mx.nd.NDArray)
+ check_block_params(x.as_np_ndarray(), TestBlock2, False, np.ndarray)
+ check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray)
+
+
+def test_optimizer_with_np_ndarrays():
+ @np.use_np_compat
+ class LinearRegression(gluon.HybridBlock):
+ def __init__(self, num_input_dim=-1, num_hidden_dim=100,
num_output_dim=10):
+ super(LinearRegression, self).__init__()
+ with self.name_scope():
+ self.w1 = self.params.get('w1', shape=(num_input_dim,
num_hidden_dim),
+ allow_deferred_init=True)
+ self.w2 = self.params.get('w2', shape=(num_hidden_dim,
num_output_dim),
+ allow_deferred_init=True)
+
+ def hybrid_forward(self, F, x, w1, w2):
+ h = x.dot(w1) # equivalent to F.np.dot(x, w1)
+ h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating
np.ndarray
+ y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)
+ return y_pred
+
+ @np.use_np_compat
+ class TotalLoss(gluon.HybridBlock):
+ def hybrid_forward(self, F, pred, label):
+ return ((pred - label) ** 2).sum() # equivalent to
F.np.sum(F.np.square(pred - label))
+
+ regressor = LinearRegression()
+ regressor.initialize(mx.init.Normal())
+ regressor.hybridize()
+
+ # Create random input and output data
+ x = mx.nd.random.normal(shape=(64, 1000)).as_np_ndarray() # x is of type
mxnet.numpy.ndarray
+ regressor(x)
+ y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type
mxnet.numpy.ndarray
+
+ total_loss = TotalLoss()
+ total_loss.hybridize()
+
+ trainer = gluon.Trainer(regressor.collect_params(),
+ 'sgd',
+ {'learning_rate': 1e-3, 'momentum': 0.9,
'allow_np': True})
+
+ for t in range(5):
+ with autograd.record():
+ output = regressor(x) # output is a type of np.ndarray because
np.dot is the last op in the network
+ loss = total_loss(output, y) # loss is a scalar np.ndarray
+ loss.backward()
+ trainer.step(1)
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_ndarray.py
b/tests/python/unittest/test_numpy_ndarray.py
index eb45234..7ffa774 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -45,9 +45,9 @@ def test_array_creation():
@with_seed()
[email protected]_np_compat
def test_zeros():
# test np.zeros in Gluon
+ @np.use_np_compat
class TestZeros(HybridBlock):
def __init__(self, shape, dtype=None):
super(TestZeros, self).__init__()
@@ -57,11 +57,13 @@ def test_zeros():
def hybrid_forward(self, F, x, *args, **kwargs):
return x + F.np.zeros(shape, dtype)
+ @np.use_np_compat
class TestZerosOutputType(HybridBlock):
def hybrid_forward(self, F, x, *args, **kwargs):
return x, F.np.zeros(shape=())
# test np.zeros in imperative
+ @np.use_np_compat
def check_zero_array_creation(shape, dtype):
np_out = _np.zeros(shape=shape, dtype=dtype)
mx_out = np.zeros(shape=shape, dtype=dtype)
@@ -93,9 +95,9 @@ def test_zeros():
@with_seed()
[email protected]_np_compat
def test_ones():
# test np.ones in Gluon
+ @np.use_np_compat
class TestOnes(HybridBlock):
def __init__(self, shape, dtype=None):
super(TestOnes, self).__init__()
@@ -105,11 +107,13 @@ def test_ones():
def hybrid_forward(self, F, x, *args, **kwargs):
return x * F.np.ones(shape, dtype)
+ @np.use_np_compat
class TestOnesOutputType(HybridBlock):
def hybrid_forward(self, F, x, *args, **kwargs):
return x, F.np.ones(shape=())
# test np.ones in imperative
+ @np.use_np_compat
def check_ones_array_creation(shape, dtype):
np_out = _np.ones(shape=shape, dtype=dtype)
mx_out = np.ones(shape=shape, dtype=dtype)
@@ -141,7 +145,6 @@ def test_ones():
@with_seed()
[email protected]_np_compat
def test_ndarray_binary_element_wise_ops():
# Cannot test operators like >, because boolean arrays are not supported
yet.
np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/':
_np.divide,
@@ -153,6 +156,7 @@ def test_ndarray_binary_element_wise_ops():
def get_np_ret(x1, x2, op):
return np_op_map[op](x1, x2)
+ @np.use_np_compat
class TestBinaryElementWiseOp(HybridBlock):
def __init__(self, op, scalar=None, reverse=False):
super(TestBinaryElementWiseOp, self).__init__()
@@ -215,6 +219,7 @@ def test_ndarray_binary_element_wise_ops():
print(self._op)
assert False
+ @np.use_np_compat
def check_binary_op_result(shape1, shape2, op, dtype=None):
if shape1 is None:
mx_input1 = abs(_np.random.uniform()) + 1
@@ -250,13 +255,6 @@ def test_ndarray_binary_element_wise_ops():
assert type(mx_out) == np.ndarray
assert np_out.shape == mx_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6,
rtol=1e-5)
-
- if mx_input1.shape == mx_input2.shape:
- # classic symbol does not support element-wise binary
broadcast.
- mx_out = get_mx_ret_classic(mx_input1, mx_input2)
- assert type(mx_out) == mx.nd.NDArray
- assert np_out.shape == mx_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6,
rtol=1e-5)
else:
get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar,
reverse=reverse)
if hybridize:
@@ -291,25 +289,18 @@ def test_ndarray_binary_element_wise_ops():
@with_seed()
def test_hybrid_block_multiple_outputs():
+ @np.use_np_compat
class TestAllNumpyOutputs(HybridBlock):
- @np.use_np_compat
def hybrid_forward(self, F, x, *args, **kwargs):
return F.npe.relu(x), F.np.sum(x)
class TestAllClassicOutputs(HybridBlock):
- @np.use_np_compat
def hybrid_forward(self, F, x, *args, **kwargs):
return F.relu(x.as_classic_ndarray()),
F.sum(x.as_classic_ndarray())
- class TestMixedTypeOutputsSuccess(HybridBlock):
- @np.use_np_compat
- def hybrid_forward(self, F, x, *args, **kwargs):
- return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x)
-
data_np = np.ones((2, 3))
for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray),
- (TestAllNumpyOutputs, np.ndarray),
- (TestMixedTypeOutputsSuccess,
np.ndarray)]:
+ (TestAllNumpyOutputs, np.ndarray)]:
net = block()
for hybridize in [True, False]:
if hybridize:
@@ -318,12 +309,13 @@ def test_hybrid_block_multiple_outputs():
assert type(out1) is expected_out_type
assert type(out2) is expected_out_type
+ @np.use_np_compat
class TestMixedTypeOutputsFailure(HybridBlock):
- @np.use_np_compat
def hybrid_forward(self, F, x, *args, **kwargs):
return F.relu(x.as_classic_ndarray()), F.np.sum(x)
net = TestMixedTypeOutputsFailure()
+ assert_exception(net, TypeError, data_np)
net.hybridize()
assert_exception(net, TypeError, data_np)
diff --git a/tests/python/unittest/test_numpy_op.py
b/tests/python/unittest/test_numpy_op.py
index 34b2cbe..e199392 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -27,7 +27,6 @@ from common import with_seed
import random
[email protected]_np_compat
@with_seed()
def test_np_sum():
class TestSum(HybridBlock):
@@ -88,8 +87,8 @@ def test_np_sum():
assert_almost_equal(mx_out.asnumpy(), np_out,
rtol=1e-3, atol=1e-5)
[email protected]_np_compat
@with_seed()
[email protected]_np_compat
def test_np_dot():
shapes = [
((3, 0), (0, 4)),
@@ -131,9 +130,9 @@ def test_np_dot():
assert False
[email protected]_np_compat
@with_seed()
def test_np_mean():
+ @np.use_np_compat
class TestMean(HybridBlock):
def __init__(self, axis=None, dtype=None, keepdims=False):
super(TestMean, self).__init__()