This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7edd69e Rename np_compat to np_shape (#15063)
7edd69e is described below
commit 7edd69eb93462f040253600e8372d1ce0e859a2d
Author: reminisce <[email protected]>
AuthorDate: Sat May 25 22:21:04 2019 -0700
Rename np_compat to np_shape (#15063)
* Change np_compat to np_shape
* Fix scala
* Fix pylint
* Add examples and fix documentation
* Fix doc
* More doc
* Rename np_compat to np_shape in test_operator.py
* Rename in ndarray.cc
---
include/mxnet/c_api.h | 6 +-
include/mxnet/imperative.h | 14 +-
python/mxnet/__init__.py | 3 +-
python/mxnet/base.py | 140 +-------------
python/mxnet/symbol/symbol.py | 5 +-
python/mxnet/util.py | 201 +++++++++++++++++++++
.../src/main/scala/org/apache/mxnet/LibInfo.scala | 4 +-
.../main/scala/org/apache/mxnet/NumpyScope.scala | 16 +-
.../src/main/scala/org/apache/mxnet/Symbol.scala | 2 +-
.../scala/org/apache/mxnet/NumpyScopeSuite.scala | 8 +-
.../main/native/org_apache_mxnet_native_c_api.cc | 12 +-
.../main/native/org_apache_mxnet_native_c_api.h | 8 +-
src/c_api/c_api.cc | 4 +-
src/c_api/c_api_executor.cc | 4 +-
src/c_api/c_api_ndarray.cc | 8 +-
src/c_api/c_api_symbolic.cc | 4 +-
src/executor/infer_graph_attr_pass.cc | 4 +-
src/imperative/imperative.cc | 4 +-
src/imperative/imperative_utils.h | 2 +-
src/ndarray/ndarray.cc | 4 +-
src/operator/tensor/init_op.h | 2 +-
tests/python/gpu/test_operator_gpu.py | 4 +-
tests/python/unittest/test_infer_shape.py | 14 +-
tests/python/unittest/test_ndarray.py | 2 +-
tests/python/unittest/test_operator.py | 30 +--
25 files changed, 285 insertions(+), 220 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 1c2300a..1b1c10e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1067,14 +1067,14 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
-MXNET_DLL int MXIsNumpyCompatible(bool* curr);
+MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
- * \param is_np_comp 1 when numpy compatibility is on, 0 when off
+ * \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
-MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
+MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index ad20991..a86cc08 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -98,13 +98,13 @@ class Imperative {
return old;
}
/*! brief whether numpy compatibility is on. */
- bool is_np_comp() const {
- return is_np_comp_;
+ bool is_np_shape() const {
+ return is_np_shape_;
}
/*! brief turn on or turn off numpy compatibility switch. */
- bool set_is_np_comp(bool is_np_comp) {
- bool old = is_np_comp_;
- is_np_comp_ = is_np_comp;
+ bool set_is_np_shape(bool is_np_shape) {
+ bool old = is_np_shape_;
+ is_np_shape_ = is_np_shape;
return old;
}
/*! \brief to record operator, return corresponding node. */
@@ -177,13 +177,13 @@ class Imperative {
static thread_local bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
- static thread_local bool is_np_comp_;
+ static thread_local bool is_np_shape_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
- static MX_THREAD_LOCAL bool is_np_comp_;
+ static MX_THREAD_LOCAL bool is_np_shape_;
#endif
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 79eb1f1..ab4bffd 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -23,7 +23,8 @@ from __future__ import absolute_import
from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
-from .base import MXNetError, is_np_compat, set_np_compat, np_compat, use_np_compat
+from .base import MXNetError
+from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from . import base
from . import contrib
from . import ndarray
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 5341401..73fae48 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,7 +20,6 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
-from functools import wraps
import atexit
import ctypes
import os
@@ -31,7 +30,7 @@ import numpy as _np
from . import libinfo
-__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat', 'use_np_compat']
+__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
@@ -735,140 +734,3 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
-
-
-def set_np_compat(active):
- """
- Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend.
-
- Parameters
- ----------
- active : bool
- Indicates whether to turn on/off NumPy compatibility.
-
- Returns
- -------
- A bool value indicating the previous state of NumPy compatibility.
- """
- prev = ctypes.c_int()
- check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active), ctypes.byref(prev)))
- return bool(prev.value)
-
-
-def is_np_compat():
- """
- Checks whether the NumPy compatibility is currently turned on.
- NumPy-compatibility is turned off by default in backend.
-
- Returns
- -------
- A bool value indicating whether the NumPy compatibility is currently on.
- """
- curr = ctypes.c_bool()
- check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
- return curr.value
-
-
-class _NumpyCompatibilityStateScope(object):
- """Scope for managing numpy compatibility state.
- Do not use this class directly. Use `np_compat(active)` instead.
-
- Example::
-
- with _NumpyCompatibilityStateScope(True):
- y = model(x)
- backward([y])
-
- """
- def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name
- self._enter_is_np_compat = is_np_compat
- self._prev_is_np_compat = None
-
- def __enter__(self):
- if self._enter_is_np_compat is not None:
- self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)
-
- def __exit__(self, ptype, value, trace):
- if self._enter_is_np_compat is not None and self._prev_is_np_compat != self._enter_is_np_compat:
- set_np_compat(self._prev_is_np_compat)
-
-
-def np_compat(active=True):
- """Returns an activated/deactivated NumPy compatibility state scope to be used in 'with' statement
- and captures code that needs the compatibility.
-
- Example::
-
- with mx.np_compat(active=True):
- # A scalar tensor's shape is `()`, whose `ndim` is `0`.
- scalar = mx.nd.ones(shape=())
- assert scalar.shape == ()
-
- # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
- data = mx.sym.var("data", shape=(0, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape()
- assert arg_shapes[0] == (0, 2, 3)
- assert out_shapes[0] == (0, 2, 3)
-
- # -1 means unknown shape dimension size in the new NumPy-compatible shape definition
- data = mx.sym.var("data", shape=(-1, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == (-1, 2, 3)
- assert out_shapes[0] == (-1, 2, 3)
-
- # When a shape is completely unknown in NumPy-compatible mode, it is
- # represented as `None` in Python.
- data = mx.sym.var("data")
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] is None
- assert out_shapes[0] is None
-
- with mx.np_compat(active=False):
- # 0 means unknown shape dimension size in the legacy shape definition.
- data = mx.sym.var("data", shape=(0, 2, 3))
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == (0, 2, 3)
- assert out_shapes[0] == (0, 2, 3)
-
- # When a shape is completely unknown in the legacy mode (default), its ndim is
- # equal to 0 and it is represented as `()` in Python.
- data = mx.sym.var("data")
- ret = mx.sym.sin(data)
- arg_shapes, out_shapes, _ = ret.infer_shape_partial()
- assert arg_shapes[0] == ()
- assert out_shapes[0] == ()
- """
- return _NumpyCompatibilityStateScope(active)
-
-
-def use_np_compat(func):
- """Wraps a function with an activated NumPy-compatibility scope. This ensures
- that the execution of the function is guaranteed with NumPy compatible semantics,
- such as zero-dim and zero size tensors.
-
- Example::
- import mxnet as mx
- @mx.use_np_compat
- def scalar_one():
- return mx.nd.ones(())
- print(scalar_one())
-
- Parameters
- ----------
- func : a user-provided callable function to be scoped by the NumPy compatibility state.
-
- Returns
- -------
- Function
- A function for wrapping the user functions in the NumPy compatibility scope.
- """
- @wraps(func)
- def _with_np_compat(*args, **kwargs):
- with np_compat(active=True):
- return func(*args, **kwargs)
-
- return _with_np_compat
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 0ea7c9f..d3cd519 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -34,7 +34,7 @@ import numpy as _numpy
from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
-from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_compat
+from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
@@ -45,6 +45,7 @@ from ..executor import Executor
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
+from ..util import is_np_shape
__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
@@ -1078,7 +1079,7 @@ class Symbol(SymbolBase):
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
- if is_np_compat():
+ if is_np_shape():
shape_is_none = not shape or -1 in shape
else:
shape_is_none = not shape or 0 in shape
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index fc8d985..29f5b78 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -19,6 +19,7 @@
import ctypes
import os
import sys
+import functools
from .base import _LIB, check_call
@@ -44,3 +45,203 @@ def get_gpu_memory(gpu_dev_id):
total_mem = ctypes.c_uint64(0)
check_call(_LIB.MXGetGPUMemoryInformation64(gpu_dev_id, ctypes.byref(free_mem), ctypes.byref(total_mem)))
return free_mem.value, total_mem.value
+
+
+def set_np_shape(active):
+ """
+ Turns on/off NumPy shape semantics, in which `()` represents the shape of scalar tensors,
+ and tuples with `0` elements, for example, `(0,)`, `(1, 0, 2)`, represent the shapes
+ of zero-size tensors. This is turned off by default for keeping backward compatibility.
+
+ Please note that this is designed as an infrastructure for the incoming
+ MXNet-NumPy operators. Legacy operators registered in the modules
+ `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+ in NumPy within this semantics.
+
+ Parameters
+ ----------
+ active : bool
+ Indicates whether to turn on/off NumPy shape semantics.
+
+ Returns
+ -------
+ A bool value indicating the previous state of NumPy shape semantics.
+
+ Example
+ -------
+ >>> import mxnet as mx
+ >>> prev_state = mx.set_np_shape(True)
+ >>> print(prev_state)
+ False
+ >>> print(mx.is_np_shape())
+ True
+ """
+ prev = ctypes.c_int()
+ check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(active), ctypes.byref(prev)))
+ return bool(prev.value)
+
+
+def is_np_shape():
+ """
+ Checks whether the NumPy shape semantics is currently turned on.
+ In NumPy shape semantics, `()` represents the shape of scalar tensors,
+ and tuples with `0` elements, for example, `(0,)`, `(1, 0, 2)`, represent
+ the shapes of zero-size tensors. This is turned off by default for keeping
+ backward compatibility.
+
+ Please note that this is designed as an infrastructure for the incoming
+ MXNet-NumPy operators. Legacy operators registered in the modules
+ `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+ in NumPy within this semantics.
+
+ Returns
+ -------
+ A bool value indicating whether the NumPy shape semantics is currently on.
+
+ Example
+ -------
+ >>> import mxnet as mx
+ >>> prev_state = mx.set_np_shape(True)
+ >>> print(prev_state)
+ False
+ >>> print(mx.is_np_shape())
+ True
+ """
+ curr = ctypes.c_bool()
+ check_call(_LIB.MXIsNumpyShape(ctypes.byref(curr)))
+ return curr.value
+
+
+class _NumpyShapeScope(object):
+ """Scope for managing NumPy shape semantics.
+ In NumPy shape semantics, `()` represents the shape of scalar tensors,
+ and tuples with `0` elements, for example, `(0,)`, `(1, 0, 2)`, represent
+ the shapes of zero-size tensors.
+
+ Do not use this class directly. Use `np_shape(active)` instead.
+
+ Example::
+
+ with _NumpyShapeScope(True):
+ y = model(x)
+ backward([y])
+
+ """
+ def __init__(self, is_np_shape): #pylint: disable=redefined-outer-name
+ self._enter_is_np_shape = is_np_shape
+ self._prev_is_np_shape = None
+
+ def __enter__(self):
+ if self._enter_is_np_shape is not None:
+ self._prev_is_np_shape = set_np_shape(self._enter_is_np_shape)
+
+ def __exit__(self, ptype, value, trace):
+ if self._enter_is_np_shape is not None and self._prev_is_np_shape != self._enter_is_np_shape:
+ set_np_shape(self._prev_is_np_shape)
+
+
+def np_shape(active=True):
+ """Returns an activated/deactivated NumPy shape scope to be used in 'with' statement
+ and captures code that needs the NumPy shape semantics, i.e. support of scalar and
+ zero-size tensors.
+
+ Please note that this is designed as an infrastructure for the incoming
+ MXNet-NumPy operators. Legacy operators registered in the modules
+ `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+ in NumPy even within this scope.
+
+ Parameters
+ ----------
+ active : bool
+ Indicates whether to activate NumPy-shape semantics.
+
+ Returns
+ -------
+ _NumpyShapeScope
+ A scope object for wrapping the code w/ or w/o NumPy-shape semantics.
+
+ Example::
+
+ with mx.np_shape(active=True):
+ # A scalar tensor's shape is `()`, whose `ndim` is `0`.
+ scalar = mx.nd.ones(shape=())
+ assert scalar.shape == ()
+
+ # If NumPy shape semantics is enabled, 0 in a shape means that
+ # dimension contains zero elements.
+ data = mx.sym.var("data", shape=(0, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape()
+ assert arg_shapes[0] == (0, 2, 3)
+ assert out_shapes[0] == (0, 2, 3)
+
+ # -1 means unknown shape dimension size in the new NumPy shape definition
+ data = mx.sym.var("data", shape=(-1, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == (-1, 2, 3)
+ assert out_shapes[0] == (-1, 2, 3)
+
+ # When a shape is completely unknown when NumPy shape semantics is on, it is
+ # represented as `None` in Python.
+ data = mx.sym.var("data")
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] is None
+ assert out_shapes[0] is None
+
+ with mx.np_shape(active=False):
+ # 0 means unknown shape dimension size in the legacy shape definition.
+ data = mx.sym.var("data", shape=(0, 2, 3))
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == (0, 2, 3)
+ assert out_shapes[0] == (0, 2, 3)
+
+ # When a shape is completely unknown in the legacy mode (default), its ndim is
+ # equal to 0 and it is represented as `()` in Python.
+ data = mx.sym.var("data")
+ ret = mx.sym.sin(data)
+ arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+ assert arg_shapes[0] == ()
+ assert out_shapes[0] == ()
+ """
+ return _NumpyShapeScope(active)
+
+
+def use_np_shape(func):
+ """Wraps a function with an activated NumPy-shape scope. This ensures
+ that the execution of the function is guaranteed with the support of
+ scalar and zero-size tensors as in NumPy.
+
+ Please note that this is designed as an infrastructure for the incoming
+ MXNet-NumPy operators. Legacy operators registered in the modules
+ `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+ in NumPy even within this scope.
+
+
+ Parameters
+ ----------
+ func : a user-provided callable function to be scoped by the NumPy-shape semantics.
+
+ Returns
+ -------
+ Function
+ A function for wrapping the user functions in the NumPy-shape semantics.
+
+
+ Examples
+ --------
+ >>> import mxnet as mx
+ >>> @mx.use_np_shape
+ ... def scalar_one():
+ ... return mx.nd.ones(())
+ ...
+ >>> print(scalar_one())
+ """
+ @functools.wraps(func)
+ def _with_np_shape(*args, **kwargs):
+ with np_shape(active=True):
+ return func(*args, **kwargs)
+
+ return _with_np_shape
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index aba6185..640ecf5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -350,6 +350,6 @@ private[mxnet] class LibInfo {
@native def mxDumpProfile(finished: Int): Int
// Numpy
- @native def mxIsNumpyCompatible(compatible: RefInt): Int
- @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int
+ @native def mxIsNumpyShape(compatible: RefInt): Int
+ @native def mxSetIsNumpyShape(isNpComp: Int, prev: RefInt): Int
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
index d3e76f1..b63095a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
@@ -25,24 +25,24 @@ import org.apache.mxnet.Base._
* is introduced first to support zero-dim and zero-size tensors as in NumPy.
*/
object NumpyScope {
- def setNumpyCompatible(isNpComp: Boolean): Boolean = {
+ def setNumpyShape(isNpComp: Boolean): Boolean = {
val prev = new RefInt()
- checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev))
+ checkCall(_LIB.mxSetIsNumpyShape(if (isNpComp) 1 else 0, prev))
if (prev.value != 0) true else false
}
- def isNumpyCompatible: Boolean = {
+ def isNumpyShape: Boolean = {
val curr = new RefInt
- checkCall(_LIB.mxIsNumpyCompatible(curr))
+ checkCall(_LIB.mxIsNumpyShape(curr))
if (curr.value != 0) true else false
}
- def enableNumpyCompatible: NumpyScope = {
+ def enableNumpyShape: NumpyScope = {
new NumpyScope(true)
}
- def disableNumpyCompatible: NumpyScope = {
+ def disableNumpyShape: NumpyScope = {
new NumpyScope(false)
}
}
@@ -51,12 +51,12 @@ class NumpyScope(var isCompatible: Boolean) {
private var prev: Boolean = false
def withScope[T](body: => T): T = {
- prev = NumpyScope.setNumpyCompatible(isCompatible)
+ prev = NumpyScope.setNumpyShape(isCompatible)
try {
body
} finally {
if (prev != isCompatible) {
- NumpyScope.setNumpyCompatible(prev)
+ NumpyScope.setNumpyShape(prev)
}
}
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 68db2b1..80f4dc9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -293,7 +293,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, values)
val argNames = listArguments()
val unknown = (argNames zip argShapes).map { case (name, shape) =>
- val shapeIsNone = if (NumpyScope.isNumpyCompatible) {
+ val shapeIsNone = if (NumpyScope.isNumpyShape) {
shape == null || shape.toVector.contains(-1)
} else {
shape == null || shape.toVector.contains(0)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
index bf6627a..0581a98 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
@@ -21,14 +21,14 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll {
test("compatible") {
- NumpyScope.enableNumpyCompatible.withScope {
- assert(NumpyScope.isNumpyCompatible === true)
+ NumpyScope.enableNumpyShape.withScope {
+ assert(NumpyScope.isNumpyShape === true)
}
}
test("incompatible") {
- NumpyScope.disableNumpyCompatible.withScope {
- assert(NumpyScope.isNumpyCompatible === false)
+ NumpyScope.disableNumpyShape.withScope {
+ assert(NumpyScope.isNumpyShape === false)
}
}
}
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 7323d23..9b19fd3 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -2707,18 +2707,18 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
}
// Numpy
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
(JNIEnv *env, jobject obj, jobject compatibleRef) {
- bool isCompatible;
- int ret = MXIsNumpyCompatible(&isCompatible);
- SetIntField(env, compatibleRef, static_cast<int>(isCompatible));
+ bool isNumpyShape;
+ int ret = MXIsNumpyShape(&isNumpyShape);
+ SetIntField(env, compatibleRef, static_cast<int>(isNumpyShape));
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
(JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
int prev;
- int ret = MXSetIsNumpyCompatible(isNpComp, &prev);
+ int ret = MXSetIsNumpyShape(isNpComp, &prev);
SetIntField(env, prevRef, prev);
return ret;
}
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 467272c..fac32bb 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -873,18 +873,18 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
/*
* Class: org_apache_mxnet_LibInfo
- * Method: mxIsNumpyCompatible
+ * Method: mxIsNumpyShape
* Signature: (Lorg/apache/mxnet/Base/RefInt;)I
*/
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
(JNIEnv *, jobject, jobject);
/*
* Class: org_apache_mxnet_LibInfo
- * Method: mxSetIsNumpyCompatible
+ * Method: mxSetIsNumpyShape
* Signature: (ILorg/apache/mxnet/Base/RefInt;)I
*/
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
(JNIEnv *, jobject, jint, jobject);
#ifdef __cplusplus
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 7f8d5f5..f5d72d5 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -521,7 +521,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
mxnet::TShape s = arr->shape();
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToLegacyShape(&s);
}
*out_dim = s.ndim();
@@ -532,7 +532,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
*out_pdata = buffer.data();
}
} else {
- if (Imperative::Get()->is_np_comp()) {
+ if (Imperative::Get()->is_np_shape()) {
*out_dim = -1;
} else {
*out_dim = 0;
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 8fade7d..ebe3f17 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -415,7 +415,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
CHECK(p.second) << "Duplicate shapes are provided for argument "
<< provided_arg_shape_names[i] << " in simple_bind";
}
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
for (auto &kv : arg_shape_map) {
common::ConvertToNumpyShape(&kv.second);
}
@@ -749,7 +749,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
CHECK(p.second) << "Duplicate shapes are provided for argument "
<< provided_arg_shape_names[i] << " in simple_bind";
}
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
for (auto &kv : arg_shape_map) {
common::ConvertToNumpyShape(&kv.second);
}
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 0e136b0..c9c6000 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -276,15 +276,15 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}
-int MXIsNumpyCompatible(bool* curr) {
+int MXIsNumpyShape(bool* curr) {
API_BEGIN();
- *curr = Imperative::Get()->is_np_comp();
+ *curr = Imperative::Get()->is_np_shape();
API_END();
}
-int MXSetIsNumpyCompatible(int is_np_comp, int* prev) {
+int MXSetIsNumpyShape(int is_np_shape, int* prev) {
API_BEGIN();
- *prev = Imperative::Get()->set_is_np_comp(static_cast<bool>(is_np_comp));
+ *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
API_END();
}
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index a3b9fce..4c6229e 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -556,7 +556,7 @@ int MXSymbolInferShape(SymbolHandle sym,
// if use legacy shape definition, need to convert numpy shape to legacy shape
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToLegacyShape(&shapes);
}
@@ -629,7 +629,7 @@ int MXSymbolInferShapeEx(SymbolHandle sym,
// if use legacy shape definition, need to convert numpy shape to legacy shape
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToLegacyShape(&shapes);
}
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index a71e5ec..d723253 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -470,7 +470,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
std::vector<int> is_dynamic(rshape.size(), 0);
// convert to numpy compatible shape to use operator's infer shape function
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToNumpyShape(&rshape);
}
@@ -490,7 +490,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
if (it != inode.source->attrs.dict.end()) {
std::istringstream is(it->second);
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToNumpyShape(&rshape[out_ent_id]);
}
}
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index f014ab9..d8fba1c 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -25,11 +25,11 @@ namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
-thread_local bool Imperative::is_np_comp_ = false;
+thread_local bool Imperative::is_np_shape_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
-MX_THREAD_LOCAL bool Imperative::is_np_comp_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
#endif
Imperative* Imperative::Get() {
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 5c97068..5cb805c 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -122,7 +122,7 @@ inline void SetShapeType(const Context& ctx,
if (!infershape.count(attrs.op)) {
is_dynamic_shape_existing = true;
} else {
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToNumpyShape(&in_shapes);
common::ConvertToNumpyShape(&out_shapes);
}
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 9474d0c..16c579f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1582,7 +1582,7 @@ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
void NDArray::Save(dmlc::Stream *strm) const {
// TODO(junwu): Support this after NumPy operators are merged
- CHECK(!Imperative::Get()->is_np_comp())
+ CHECK(!Imperative::Get()->is_np_shape())
<< "Saving ndarray within the scope of np_shape is not supported.";
// write magic number to mark this version
// for storage type
@@ -1702,7 +1702,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
bool NDArray::Load(dmlc::Stream *strm) {
// TODO(junwu): Support this after NumPy operators are merged
- CHECK(!Imperative::Get()->is_np_comp())
+ CHECK(!Imperative::Get()->is_np_shape())
<< "Loading ndarray within the scope of np_shape is not supported.";
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index e4b090d..fd49153 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -242,7 +242,7 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape param_shape = param.shape;
- if (!Imperative::Get()->is_np_comp()) {
+ if (!Imperative::Get()->is_np_shape()) {
common::ConvertToNumpyShape(&param_shape);
}
if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9c004cd..064f783 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2012,14 +2012,14 @@ def test_multi_proposal_op():
# The following 2 functions launch 0-thread kernels, an error that should be caught and signaled.
def kernel_error_check_imperative():
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
- with mx.np_compat(active=True):
+ with mx.np_shape(active=True):
a = mx.nd.array([1,2,3],ctx=mx.gpu(0))
b = mx.nd.array([],ctx=mx.gpu(0))
c = (a / b).asnumpy()
def kernel_error_check_symbolic():
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
- with mx.np_compat(active=True):
+ with mx.np_shape(active=True):
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = a / b
diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py
index 2bf7e8b..1312be0 100644
--- a/tests/python/unittest/test_infer_shape.py
+++ b/tests/python/unittest/test_infer_shape.py
@@ -154,7 +154,7 @@ def test_shape_completely_unknown():
assert arg_shapes[0] == ()
assert out_shapes[0] == ()
- with mx.np_compat():
+ with mx.np_shape():
data = mx.sym.var("data")
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
@@ -169,7 +169,7 @@ def test_dot_partial_shape():
# batch size(first dim) of lhs unknown
_, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
assert result_shape == [(0, 3, 5)]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(4, 5))
assert result_shape == [(-1, 3, 5)]
@@ -184,7 +184,7 @@ def test_batch_dot_partial_shape():
# rhs second dim unknown
_, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
assert result_shape == [()]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
assert result_shape == [(-1, 3, 5)]
_, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1,
5))
@@ -198,7 +198,7 @@ def test_embedding_partial_shape():
y = mx.sym.Embedding(data=x, weight=w, input_dim=100, output_dim=10)
_, result_shape, _ = y.infer_shape_partial(x=(0, 5), w=(100, 10))
assert result_shape == [(0, 5, 10)]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result_shape, _ = y.infer_shape_partial(x=(-1, 5), w=(100, 10))
assert result_shape == [(-1, 5, 10)]
@@ -213,7 +213,7 @@ def test_transpose_partial_shape():
_, result, _ = y.infer_shape_partial(x=(0, 3, 224, 224))
assert result == [(0, 224, 224, 3)]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result, _ = y.infer_shape_partial(x=(-1, 3, 224, 224))
assert result == [(-1, 224, 224, 3)]
@@ -225,7 +225,7 @@ def test_pick_partial_shape():
# batch size unknown
_, result, _ = y.infer_shape_partial(x=(0, 3, 3), index=(0, 3,))
assert result == [(0, 3)]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result, _ = y.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3,))
assert result == [(-1, 3)]
@@ -240,7 +240,7 @@ def test_where_partial_shape():
assert result == [()]
_, result, _ = where_op.infer_shape_partial(cond=(0,), x=(2, 2), y =(2, 2))
assert result == [()]
- with mx.np_compat(True):
+ with mx.np_shape(True):
_, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2),
y =(-1, 2))
assert result == [None]
_, result, _ = where_op.infer_shape_partial(cond=(-1,), x=(2, 2),
y=(2, 2))
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 8998b21..8b2a270 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -123,7 +123,7 @@ def test_ndarray_setitem():
# numpy assignment for empty axis
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
if trivial_shape == tuple():
- with mx.np_compat():
+ with mx.np_shape():
x = mx.nd.zeros(trivial_shape)
else:
x = mx.nd.zeros(trivial_shape)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7767863..52fe69b 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4610,7 +4610,7 @@ def test_tile():
assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3))
test_normal_case()
- with mx.np_compat():
+ with mx.np_shape():
test_empty_tensor()
test_empty_reps()
test_tile_backward()
@@ -4671,7 +4671,7 @@ def test_one_hot():
test_normal_case(index_type=np.float64)
test_normal_case(index_type=np.float32)
test_normal_case(index_type=np.float16)
- with mx.np_compat():
+ with mx.np_shape():
test_empty_indices()
test_zero_depth()
@@ -7222,7 +7222,7 @@ def test_slice_partial_infer():
check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0))
check_slice_axis_partial_infer(var1, 1, 0, 5, (10, 0))
- with mx.np_compat():
+ with mx.np_shape():
var1 = mx.sym.var(name="data", shape=(-1, 20))
check_slice_partial_infer(var1, (None, None), (None, 10), [], (-1, 10))
check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2),
(-1, 5))
@@ -7247,7 +7247,7 @@ def test_float16_min_max():
@with_seed()
-@mx.use_np_compat
+@mx.use_np_shape
def test_zero_size_min_max():
def min():
a = mx.nd.zeros(shape=(5, 0))
@@ -8457,7 +8457,7 @@ def test_index_array():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array],
[np.ones(expected.shape)], [np.zeros_like(input_array)])
- @mx.use_np_compat
+ @mx.use_np_shape
def test_index_array_default_zero_dim():
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)
@@ -8468,7 +8468,7 @@ def test_index_array():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array],
[np.ones(expected.shape)], [np.zeros_like(input_array)])
- @mx.use_np_compat
+ @mx.use_np_shape
def test_index_array_default_zero_size():
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)
@@ -8492,7 +8492,7 @@ def test_index_array():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array],
[np.ones(expected.shape)], [np.zeros_like(input_array)])
- @mx.use_np_compat
+ @mx.use_np_shape
def test_index_array_select_axes_zero_size():
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data, axes=(2, 1))
@@ -8502,7 +8502,7 @@ def test_index_array():
check_symbolic_forward(index_array, [input_array], [expected])
check_symbolic_backward(index_array, [input_array],
[np.ones(expected.shape)], [np.zeros_like(input_array)])
-
+
test_index_array_default()
test_index_array_default_zero_dim()
test_index_array_default_zero_size()
@@ -8514,7 +8514,7 @@ def test_index_array():
def test_scalar_tensor_creation():
assertRaises(MXNetError, mx.nd.zeros, shape=())
assertRaises(MXNetError, mx.nd.ones, shape=())
- with mx.np_compat():
+ with mx.np_shape():
data_mx = mx.nd.ones(shape=())
data_np = np.ones((), dtype=data_mx.dtype)
assert same(data_mx.asnumpy(), data_np)
@@ -8524,7 +8524,7 @@ def test_scalar_tensor_creation():
def test_zero_size_tensor_creation():
assertRaises(MXNetError, mx.nd.zeros, shape=(0, 1, 3, 0))
assertRaises(MXNetError, mx.nd.ones, shape=(0, 1, 3, 0))
- with mx.np_compat():
+ with mx.np_shape():
data_mx = mx.nd.ones(shape=(0, 1, 0, 4))
data_np = np.ones(shape=data_mx.shape, dtype=data_mx.dtype)
assert same(data_mx.asnumpy(), data_np)
@@ -8532,7 +8532,7 @@ def test_zero_size_tensor_creation():
@with_seed()
def test_concat_with_zero_size_tensor():
- with mx.np_compat():
+ with mx.np_shape():
data1 = mx.nd.ones((0, 8, 12))
data2 = mx.nd.ones((3, 8, 12))
data3 = mx.nd.ones((0, 8, 12))
@@ -8547,8 +8547,8 @@ def test_concat_with_zero_size_tensor():
@with_seed()
-def test_np_compat_decorator():
- @mx.use_np_compat
+def test_np_shape_decorator():
+ @mx.use_np_shape
def check_scalar_one():
"""Generate scalar one tensor"""
return mx.nd.ones(shape=())
@@ -8556,12 +8556,12 @@ def test_np_compat_decorator():
assert check_scalar_one.__doc__ == "Generate scalar one tensor"
assert check_scalar_one().shape == ()
for active in [True, False]:
- with mx.np_compat(active=active):
+ with mx.np_shape(active=active):
assert check_scalar_one.__name__ == "check_scalar_one"
assert check_scalar_one.__doc__ == "Generate scalar one tensor"
assert check_scalar_one().shape == ()
- @mx.use_np_compat
+ @mx.use_np_shape
def check_concat(shape1, shape2, axis):
data1 = mx.nd.ones(shape1)
data2 = mx.nd.ones(shape2)