sxjscience commented on a change in pull request #18251:
URL: https://github.com/apache/incubator-mxnet/pull/18251#discussion_r424810814



##########
File path: python/mxnet/util.py
##########
@@ -929,3 +938,204 @@ def default_array(source_array, ctx=None, dtype=None):
         return _mx_np.array(source_array, ctx=ctx, dtype=dtype)
     else:
         return _mx_nd.array(source_array, ctx=ctx, dtype=dtype)
+
+class _NumpyDefaultDtypeScope(object):
+    """Scope for managing NumPy default dtype semantics.
+
+    In NumPy default dtype semantics, the default dtype is 'float64',
+    i.e. np.array([1, 2, 3]).dtype = np.float64.
+    Without this semantic, the original default dtype is 'float32'.
+
+    Do not use this class directly. Use `np_default_dtype(active)` instead.
+
+    Parameters
+    ----------
+    is_np_default_dtype : bool or None
+        Value passed to `set_np_default_dtype` when the scope is entered
+        (e.g. True to enable NumPy default dtype semantics). If None, the
+        scope leaves the current setting unchanged on entry and on exit.
+
+    Example::
+
+        with _NumpyDefaultDtypeScope(True):
+            y = model(x)
+            backward([y])
+
+    """
+    def __init__(self, is_np_default_dtype):  #pylint: disable=redefined-outer-name
+        # None is a "no-op" request: neither __enter__ nor __exit__ touches the setting.
+        self._enter_is_np_default_dtype = is_np_default_dtype
+        self._prev_is_np_default_dtype = None
+
+    def __enter__(self):
+        if self._enter_is_np_default_dtype is not None:
+            # The return value of set_np_default_dtype is kept as the setting to restore on exit.
+            self._prev_is_np_default_dtype = set_np_default_dtype(
+                self._enter_is_np_default_dtype)
+
+    def __exit__(self, ptype, value, trace):
+        # Restore only when __enter__ actually switched the setting to a different value.
+        if self._enter_is_np_default_dtype is not None and\
+           self._prev_is_np_default_dtype != self._enter_is_np_default_dtype:
+            set_np_default_dtype(self._prev_is_np_default_dtype)
+
+def np_default_dtype(active=True):
+    """Returns an activated/deactivated NumPy-default_dtype scope to be used 
in 'with' statement
+    and captures code that needs the NumPy default dtype semantics. i.e. 
default dtype is float64.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy even within this scope.
+
+    Parameters
+    ----------
+    active : bool
+        Indicates whether to activate NumPy default dtype semantics.
+
+    Returns
+    -------
+    _NumpyDefaultDtypeScope
+        A scope object for wrapping the code w/ or w/o NumPy-default_dtype 
semantics.
+
+    Example::
+
+        with mx.np_default_Dtype(active=True):
+            # Default Dtype is 'float64', consistent with offical NumPy 
behavior.
+            arr = mx.nd.array([1, 2, 3])

Review comment:
       Should we use `mx.np.array`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to