This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 07661ae  decouple record/train and add state readers (#7356)
07661ae is described below

commit 07661ae9a627d2a90b15c04b665fdb0773920285
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Tue Aug 8 15:13:29 2017 -0700

    decouple record/train and add state readers (#7356)
    
    * decouple record/train and add state readers
    
    * update per comments
    
    * update per consensus
    
    * add API doc
    
    * fix
---
 docs/api/python/autograd.md            |  21 +++--
 include/mxnet/c_api.h                  |  12 +++
 python/mxnet/autograd.py               | 136 ++++++++++++++++++++++++---------
 python/mxnet/ndarray.py                |   7 +-
 src/c_api/c_api_ndarray.cc             |  12 +++
 tests/python/unittest/test_autograd.py |  37 ++++++++-
 6 files changed, 174 insertions(+), 51 deletions(-)
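
For orientation, a minimal usage sketch of the decoupled API this commit introduces;
the tensor and the Dropout layer below are placeholders chosen for illustration, not
part of the patch:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((10, 10))
    x.attach_grad()

    # Recording (graph construction) and train/predict mode are now separate
    # switches: record() turns recording on, while train_mode()/predict_mode()
    # only change how layers such as Dropout behave in the forward pass.
    with autograd.record():                   # recording on, training mode by default
        y = mx.nd.Dropout(x, p=0.5)           # drops inputs randomly
        with autograd.predict_mode():         # keep recording, switch forward behavior
            z = mx.nd.Dropout(y, p=0.5)       # acts as identity in predicting mode
        y.backward()

    # The new state readers expose the current global flags.
    print(autograd.is_recording())            # False outside the scope
    print(autograd.is_training())             # False outside the scope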

diff --git a/docs/api/python/autograd.md b/docs/api/python/autograd.md
index 440a1e4..d204a2c 100644
--- a/docs/api/python/autograd.md
+++ b/docs/api/python/autograd.md
@@ -14,19 +14,28 @@
 ## Autograd
 
 ```eval_rst
-.. currentmodule:: mxnet.autograd
-```
-
-
-```eval_rst
 .. autosummary::
     :nosignatures:
 
     record
     pause
-    mark_variables
+    train_mode
+    predict_mode
     backward
     set_training
+    is_training
     set_recording
+    is_recording
+    mark_variables
+```
+
+## API Reference
+
+<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+.. automodule:: mxnet.autograd
+    :members:
 ```
 
+<script>auto_index("api-reference");</script>
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d9a5315..3b8d54c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -566,6 +566,18 @@ MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev);
  */
 MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev);
 /*!
+ * \brief get whether autograd recording is on
+ * \param curr returns the current status.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradIsRecording(bool* curr);
+/*!
+ * \brief get whether training mode is on
+ * \param curr returns the current status.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradIsTraining(bool* curr);
+/*!
  * \brief mark NDArrays as variables to compute gradient for autograd
  * \param num_var number of variable NDArrays
  * \param var_handles variable NDArrays
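
Both new entry points write the current flag into a bool out-parameter and return 0 on
success. A rough ctypes sketch of querying them from Python (query_autograd_state is a
hypothetical helper; it simply mirrors the wrappers added in python/mxnet/autograd.py below):

    import ctypes
    from mxnet.base import _LIB, check_call

    def query_autograd_state():
        # Each reader fills a bool out-parameter; check_call raises on a non-zero return.
        recording = ctypes.c_bool()
        training = ctypes.c_bool()
        check_call(_LIB.MXAutogradIsRecording(ctypes.byref(recording)))
        check_call(_LIB.MXAutogradIsTraining(ctypes.byref(training)))
        return recording.value, training.value
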
diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py
index 2f33052..2c3feab 100644
--- a/python/mxnet/autograd.py
+++ b/python/mxnet/autograd.py
@@ -10,7 +10,7 @@ from .ndarray import NDArray
 from .symbol import _GRAD_REQ_MAP
 
 
-def set_recording(is_recording):
+def set_recording(is_recording): #pylint: disable=redefined-outer-name
     """Set status to recording/not recording. When recording, graph will be 
constructed
     for gradient computation.
 
@@ -27,14 +27,14 @@ def set_recording(is_recording):
         ctypes.c_int(is_recording), ctypes.byref(prev)))
     return bool(prev.value)
 
-def set_training(is_train):
-    """Set status to training/not training. This affects ctx.is_train in 
operator
+def set_training(train_mode): #pylint: disable=redefined-outer-name
+    """Set status to training/predicting. This affects ctx.is_train in operator
     running context. For example, Dropout will drop inputs randomly when
-    is_train=True while simply passing through if is_train=False.
+    train_mode=True while simply passing through if train_mode=False.
 
     Parameters
     ----------
-    is_train: bool
+    train_mode: bool
 
     Returns
     -------
@@ -42,43 +42,70 @@ def set_training(is_train):
     """
     prev = ctypes.c_int()
     check_call(_LIB.MXAutogradSetIsTraining(
-        ctypes.c_int(is_train), ctypes.byref(prev)))
+        ctypes.c_int(train_mode), ctypes.byref(prev)))
     return bool(prev.value)
 
+def is_recording():
+    """Get status on recording/not recording.
 
-class RecordingStateScope(object):
+    Returns
+    -------
+    Current state of recording.
+    """
+    curr = ctypes.c_bool()
+    check_call(_LIB.MXAutogradIsRecording(ctypes.byref(curr)))
+    return curr.value
+
+def is_training():
+    """Get status on training/predicting.
+
+    Returns
+    -------
+    Current state of training/predicting.
+    """
+    curr = ctypes.c_bool()
+    check_call(_LIB.MXAutogradIsTraining(ctypes.byref(curr)))
+    return curr.value
+
+
+class _RecordingStateScope(object):
     """Scope for managing training state.
 
     Example::
-        with RecordingStateScope(True, True):
+
+        with _RecordingStateScope(True, True):
             y = model(x)
             backward([y])
+
     """
-    def __init__(self, enter_state, is_train):
-        self._enter_state = enter_state
-        self._enter_is_train = is_train
-        self._prev = None
-        self._prev_is_train = None
+    def __init__(self, is_record, train_mode): #pylint: disable=redefined-outer-name
+        self._enter_is_record = is_record
+        self._enter_train_mode = train_mode
+        self._prev_is_record = None
+        self._prev_train_mode = None
 
     def __enter__(self):
-        self._prev = set_recording(self._enter_state)
-        self._prev_is_train = set_training(self._enter_is_train)
+        if self._enter_is_record is not None:
+            self._prev_is_record = set_recording(self._enter_is_record)
+        if self._enter_train_mode is not None:
+            self._prev_train_mode = set_training(self._enter_train_mode)
 
     def __exit__(self, ptype, value, trace):
-        if self._prev != self._enter_state:
-            set_recording(self._prev)
-        if self._prev_is_train != self._enter_is_train:
-            set_training(self._prev_is_train)
+        if self._enter_is_record is not None and self._prev_is_record != self._enter_is_record:
+            set_recording(self._prev_is_record)
+        if self._enter_train_mode is not None and self._prev_train_mode != self._enter_train_mode:
+            set_training(self._prev_train_mode)
 
 
-def record(is_train=True):
-    """Returns a training scope context to be used in 'with' statement
-    and captures training code.
+def record(train_mode=True): #pylint: disable=redefined-outer-name
+    """Returns an autograd recording scope context to be used in 'with' 
statement
+    and captures code that needs gradients to be calculated.
 
-    .. note:: When forwarding with is_train=False, the corresponding backward
-              should also use is_train=False, otherwise gradient is undefined.
+    .. note:: When forwarding with train_mode=False, the corresponding backward
+              should also use train_mode=False, otherwise gradient is undefined.
 
     Example::
+
         with autograd.record():
             y = model(x)
             backward([y])
@@ -87,17 +114,19 @@ def record(is_train=True):
 
     Parameters
     ----------
-    is_train: bool, default True
-        Whether to do forward for training or inference.
+    train_mode: bool, default True
+        Whether the forward pass is in training or predicting mode. This controls the behavior
+        of some layers such as Dropout, BatchNorm.
     """
-    return RecordingStateScope(True, is_train)
+    return _RecordingStateScope(True, train_mode)
 
 
-def pause(is_train=False):
-    """Returns a testing scope context to be used in 'with' statement
-    and captures testing code.
+def pause(train_mode=False): #pylint: disable=redefined-outer-name
+    """Returns a scope context to be used in 'with' statement for codes that 
do not need
+    gradients to be calculated.
 
     Example::
+
         with autograd.record():
             y = model(x)
             backward([y])
@@ -106,10 +135,41 @@ def pause(is_train=False):
 
     Parameters
     ----------
-    is_train: bool, default False
-        Whether to do forward for training or inference.
+    train_mode: bool, default False
+        Whether to do forward for training or predicting.
+    """
+    return _RecordingStateScope(False, train_mode)
+
+
+def train_mode():
+    """Returns a scope context to be used in 'with' statement
+    in which forward pass behavior is set to training mode,
+    without changing the recording states.
+
+    Example::
+
+        y = model(x)
+        with autograd.train_mode():
+            y = dropout(y)
+
+    """
+    return _RecordingStateScope(None, True)
+
+
+def predict_mode():
+    """Returns a scope context to be used in 'with' statement
+    in which forward pass behavior is set to inference mode,
+    without changing the recording states.
+
+    Example::
+
+        with autograd.record():
+            y = model(x)
+            with autograd.predict_mode():
+                y = sampling(y)
+            backward([y])
     """
-    return RecordingStateScope(False, is_train)
+    return _RecordingStateScope(None, False)
 
 
 def mark_variables(variables, gradients, grad_reqs='write'):
@@ -143,7 +203,7 @@ def mark_variables(variables, gradients, grad_reqs='write'):
         c_array(NDArrayHandle, gradient_handles)))
 
 
-def backward(heads, head_grads=None, retain_graph=False, is_train=True):
+def backward(heads, head_grads=None, retain_graph=False, train_mode=True): #pylint: disable=redefined-outer-name
     """Compute the gradients of heads w.r.t previously marked variables.
 
     Parameters
@@ -152,8 +212,8 @@ def backward(heads, head_grads=None, retain_graph=False, is_train=True):
         Output NDArray(s)
     head_grads: NDArray or list of NDArray or None
         Gradients with respect to heads.
-    is_train: bool, optional
-        Whether to do backward for training or inference.
+    train_mode: bool, optional
+        Whether to do backward for training or predicting.
     """
     if isinstance(heads, NDArray):
         assert head_grads is None or isinstance(head_grads, NDArray)
@@ -170,7 +230,7 @@ def backward(heads, head_grads=None, retain_graph=False, is_train=True):
             c_array(NDArrayHandle, output_handles),
             ctypes.c_void_p(0),
             ctypes.c_int(retain_graph),
-            ctypes.c_int(is_train)))
+            ctypes.c_int(train_mode)))
         return
 
     ograd_handles = []
@@ -187,4 +247,4 @@ def backward(heads, head_grads=None, retain_graph=False, is_train=True):
         c_array(NDArrayHandle, output_handles),
         c_array(NDArrayHandle, ograd_handles),
         ctypes.c_int(retain_graph),
-        ctypes.c_int(is_train)))
+        ctypes.c_int(train_mode)))
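
Taken together, the four scopes in autograd.py now compose as follows; this is an
illustrative sketch along the lines of the updated tests, not part of the patch:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((3, 4))
    x.attach_grad()

    # record()/pause() control whether a graph is built; train_mode()/predict_mode()
    # control only the forward behavior (Dropout, BatchNorm) and leave recording alone.
    with autograd.record(train_mode=False):   # build a graph, predicting-mode forward
        y = mx.nd.Dropout(x, p=0.5)           # identity in predicting mode
        y.backward(train_mode=False)          # backward mode should match the forward
        with autograd.train_mode():           # flip only the forward behavior
            z = mx.nd.Dropout(x, p=0.5)       # drops randomly, still recorded
            z.backward()                      # train_mode=True is the default
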
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index b2178a9..d4a0cdb 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -1059,7 +1059,7 @@ fixed-size items.
         check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
         return NDArray(hdl)
 
-    def backward(self, out_grad=None, retain_graph=False, is_train=True):
+    def backward(self, out_grad=None, retain_graph=False, train_mode=True):
         """Compute the gradients of this NDArray w.r.t variables.
 
         Parameters
@@ -1070,7 +1070,7 @@ fixed-size items.
             Whether to retain the computation graph for another backward
             pass on the same graph. By default the computation history
             is cleared.
-        is_train : bool, optional
+        train_mode : bool, optional
             Whether to compute gradient for training or inference.
         """
         if out_grad is None:
@@ -1082,7 +1082,7 @@ fixed-size items.
             1, c_array(NDArrayHandle, [self.handle]),
             c_array(NDArrayHandle, ograd_handles),
             ctypes.c_int(retain_graph),
-            ctypes.c_int(is_train)))
+            ctypes.c_int(train_mode)))
 
 
 def onehot_encode(indices, out):
@@ -2538,7 +2538,6 @@ def _make_ndarray_function(handle, name):
         else:
             signature.append('%s=_Null'%name)
             kwarg_names.append(name)
-    #signature.append('is_train=False')
     signature.append('out=None')
     signature.append('name=None')
     signature.append('**kwargs')
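
On the NDArray side the only user-visible change is the renamed keyword (is_train
becomes train_mode); a brief sketch under the same assumptions as above:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((2, 2))
    x.attach_grad()
    with autograd.record(train_mode=False):
        y = (x * 2).sum()
    y.backward(train_mode=False)              # was y.backward(is_train=False) before this commit
    print(x.grad)                             # gradient of sum(2*x) w.r.t. x is 2
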
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index f401394..a37e314 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -522,12 +522,24 @@ int MXInvokeCachedOp(CachedOpHandle handle,
   API_END();
 }
 
+int MXAutogradIsTraining(bool* curr) {
+  API_BEGIN();
+  *curr = AutogradRuntime::Get()->IsTraining();
+  API_END();
+}
+
 int MXAutogradSetIsTraining(int is_training, int* prev) {
   API_BEGIN();
   *prev = AutogradRuntime::Get()->SetIsTraining(static_cast<bool>(is_training));
   API_END();
 }
 
+int MXAutogradIsRecording(bool* curr) {
+  API_BEGIN();
+  *curr = AutogradRuntime::Get()->IsRecording();
+  API_END();
+}
+
 int MXAutogradSetIsRecording(int is_recording, int* prev) {
   API_BEGIN();
   *prev = AutogradRuntime::Get()->SetIsRecording(static_cast<bool>(is_recording));
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 172075d..7ee3500 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -251,18 +251,49 @@ def test_attach_grad():
 def test_is_train():
     x = mx.nd.ones((10, 10))
     x.attach_grad()
-    with record(True):
+    with record(train_mode=True):
+        assert is_recording()
+        assert is_training()
         y = mx.nd.Dropout(x, p=0.5)
         assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
         y.backward()
         assert (x.grad.asnumpy() == y.asnumpy()).all()
 
-    with record(False):
+        with predict_mode():
+            assert is_recording()
+            assert not is_training()
+            y = mx.nd.Dropout(x, p=0.5)
+            assert (y.asnumpy() == x.asnumpy()).all()
+            y.backward(train_mode=False)
+            assert (x.grad.asnumpy() == x.asnumpy()).all()
+
+    with record(train_mode=False):
+        assert is_recording()
+        assert not is_training()
         y = mx.nd.Dropout(x, p=0.5)
         assert (y.asnumpy() == x.asnumpy()).all()
-        y.backward(is_train=False)
+        y.backward(train_mode=False)
         assert (x.grad.asnumpy() == x.asnumpy()).all()
 
+        with train_mode():
+            assert is_recording()
+            assert is_training()
+            y = mx.nd.Dropout(x, p=0.5)
+            assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
+            y.backward()
+            assert (x.grad.asnumpy() == y.asnumpy()).all()
+
+    assert not is_recording()
+    assert not is_training()
+    y = mx.nd.Dropout(x, p=0.5)
+    assert (y.asnumpy() == x.asnumpy()).all()
+
+    with train_mode():
+        assert not is_recording()
+        assert is_training()
+        y = mx.nd.Dropout(x, p=0.5)
+        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
+
 
 if __name__ == "__main__":
     import nose
