This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f8c1e8  add inline for cached op & fixed a bug when calling backward on variable (#8701)
2f8c1e8 is described below

commit 2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Tue Nov 28 10:44:47 2017 -0800

    add inline for cached op & fixed a bug when calling backward on variable (#8701)
    
    * add inline for cached op & fixed a bug when calling backward on variable
    
    * fix
    
    * Update test_gluon.py
---
 include/mxnet/c_api.h                 | 11 ++++++++--
 include/mxnet/imperative.h            | 37 +++++++++++++++++++++++++++----
 python/mxnet/_ctypes/ndarray.py       |  7 ++++--
 python/mxnet/gluon/block.py           | 20 ++++++++++++-----
 python/mxnet/gluon/nn/basic_layers.py |  6 +++--
 src/c_api/c_api_ndarray.cc            | 31 ++++++++++++++++----------
 src/imperative/cached_op.cc           | 41 +++++++++++++++++++++++++++++------
 src/imperative/imperative.cc          | 16 +++++++++-----
 tests/python/unittest/test_gluon.py   | 27 ++++++++++++++++++++++-
 9 files changed, 157 insertions(+), 39 deletions(-)
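
The user-facing change: hybridize() now takes keyword flags that are forwarded to the CachedOp backend. A minimal sketch of the new call, following the test added at the bottom of this patch (layer sizes arbitrary):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(10))
        net.add(nn.Dense(10))
    net.initialize()

    # inline_limit: forward graphs with at most this many non-input nodes
    # are executed inline while recording, instead of being recorded as a
    # single _CachedOp node on the autograd tape.
    net.hybridize(inline_limit=3)

    with mx.autograd.record():
        y = net(mx.nd.zeros((1, 10)))
    y.backward()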

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9815786..faa4535 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -793,8 +793,15 @@ MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
 /*!
  * \brief create cached operator
  */
-MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
-                               CachedOpHandle *out);
+MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
+/*!
+ * \brief create cached operator with configuration flags
+ */
+MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
+                                 int num_params,
+                                 const char** keys,
+                                 const char** vals,
+                                 CachedOpHandle *out);
 /*!
  * \brief free cached operator
  */
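
For reference, the new entry point can be driven directly over ctypes. A sketch, assuming check_call, c_str_array and CachedOpHandle are exported from mxnet.base as in this tree:

    import ctypes
    import mxnet as mx
    from mxnet.base import _LIB, check_call, c_str_array, CachedOpHandle

    sym = 2 * mx.sym.var('x')
    handle = CachedOpHandle()
    flags = [('inline_limit', '2'), ('forward_bulk_size', '15')]
    check_call(_LIB.MXCreateCachedOpEx(
        sym.handle,
        len(flags),
        c_str_array([k for k, _ in flags]),
        c_str_array([v for _, v in flags]),
        ctypes.byref(handle)))
    check_call(_LIB.MXFreeCachedOp(handle))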
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 88a9f4d..d605e9d 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -28,11 +28,30 @@
 #include <nnvm/graph.h>
 #include <vector>
 #include <atomic>
+#include <utility>
+#include <string>
 #include <unordered_map>
 
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpParam) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,7 +96,8 @@ class Imperative {
   };
   class CachedOp {
    public:
-    explicit CachedOp(const nnvm::Symbol& sym);
+    CachedOp(const nnvm::Symbol& sym,
+             const std::vector<std::pair<std::string, std::string> >& kwargs);
     uint32_t num_inputs() {
       return fwd_graph_.indexed_graph().input_nodes().size();
     }
@@ -103,8 +123,9 @@ class Imperative {
                                  const std::vector<NDArray*>& inputs);
     std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
                                          const std::vector<nnvm::NodeEntry>& ograds);
-    OpStatePtr Forward(const std::vector<NDArray*>& inputs,
-                       const std::vector<NDArray*>& outputs);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& inputs,
+                 const std::vector<NDArray*>& outputs);
     void Backward(const bool retain_graph,
                   const OpStatePtr& state,
                   const std::vector<NDArray*>& inputs,
@@ -117,9 +138,11 @@ class Imperative {
       std::vector<OpStatePtr> states;
     };
     std::mutex mutex_;
+    CachedOpParam param_;
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    bool inlining_;
     std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
@@ -182,7 +205,11 @@ class Imperative {
  private:
   friend class NDArray;
   /*! \brief make constructor protected. */
-  Imperative() {}
+  Imperative() {
+    if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
+      backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+    }
+  }
   /*! \brief find the input/output ndarrays that are needed for backward */
   void GetBackwardDependency(
       const nnvm::NodePtr& node,
@@ -210,6 +237,8 @@ class Imperative {
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
   std::atomic<uint64_t> variable_count_{0};
+  /*! \brief default backward bulk size */
+  int backward_bulk_size_{0};
 };
 
 using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
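
CachedOpParam gives the op three knobs: inline_limit (inline small graphs into the caller's recording), plus forward/backward bulk sizes, both defaulting to MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN. The flags travel as string key/value pairs. A sketch at the Python level, mirroring the flags format used by the binding changed below:

    import mxnet as mx

    sym = mx.sym.exp(mx.sym.var('x'))
    # flags are (key, value) pairs; the binding stringifies the values
    op = mx.nd.CachedOp(sym, flags=[('inline_limit', 2),
                                    ('forward_bulk_size', 8)])
    y = op(mx.nd.ones((2, 3)))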
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index a0c01a6..20ad2bf 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,10 +105,13 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym):
+    def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
-        check_call(_LIB.MXCreateCachedOp(
+        check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
+            len(flags),
+            c_str_array([key for key, _ in flags]),
+            c_str_array([str(val) for _, val in flags]),
             ctypes.byref(self.handle)))
 
     def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 466f87f..37734ac 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -274,7 +274,7 @@ class Block(object):
         """
         self.collect_params().initialize(init, ctx, verbose)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has 
no effect on
         non-hybrid children.
 
@@ -282,9 +282,11 @@ class Block(object):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for the hybridized operator.
         """
         for cld in self._children:
-            cld.hybridize(active)
+            cld.hybridize(active, **kwargs)
 
     def cast(self, dtype):
         """Cast this Block to use another data type.
@@ -343,6 +345,7 @@ class HybridBlock(Block):
         self._out_format = None
         self._in_format = None
         self._active = False
+        self._flags = {}
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -378,7 +381,7 @@ class HybridBlock(Block):
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
         input_idx = {var.name: i for i, var in enumerate(inputs)}
-        self._cached_op = ndarray.CachedOp(out)
+        self._cached_op = ndarray.CachedOp(out, self._flags)
         params = dict(self.collect_params().items())
 
         # verify graph inputs
@@ -437,9 +440,11 @@ class HybridBlock(Block):
         super(HybridBlock, self).register_child(block)
         self._clear_cached_op()
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         self._active = active
-        super(HybridBlock, self).hybridize(active)
+        self._flags = kwargs.items()
+        self._clear_cached_op()
+        super(HybridBlock, self).hybridize(active, **kwargs)
 
     def cast(self, dtype):
         self._clear_cached_op()
@@ -615,5 +620,10 @@ class SymbolBlock(HybridBlock):
         ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
         return _regroup(list(ret), self._out_format)[0]
 
+    def _clear_cached_op(self):
+        tmp = self._cached_graph
+        super(SymbolBlock, self)._clear_cached_op()
+        self._cached_graph = tmp
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
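
The SymbolBlock override above matters because its graph is fixed at construction rather than traced from hybrid_forward: re-hybridizing must rebuild the CachedOp without discarding that graph. An illustrative sketch (a parameter-free graph, so no initialization is needed):

    import mxnet as mx

    data = mx.sym.var('data')
    net = mx.gluon.SymbolBlock(outputs=[mx.sym.tanh(data)], inputs=[data])
    net.hybridize(inline_limit=0)  # triggers _clear_cached_op; the wrapped graph survives
    y = net(mx.nd.zeros((1, 4)))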
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index c0b4b52..ab5d5e1 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -68,7 +68,7 @@ class Sequential(Block):
     def __len__(self):
         return len(self._children)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates `HybridBlock`s recursively. Has no effect 
on
         non-hybrid children.
 
@@ -76,11 +76,13 @@ class Sequential(Block):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for the hybridized operator.
         """
         if self._children and all(isinstance(c, HybridBlock) for c in self._children):
             warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \
                           'using HybridSequential for the best performance.')
-        super(Sequential, self).hybridize(active)
+        super(Sequential, self).hybridize(active, **kwargs)
 
 
 class HybridSequential(HybridBlock):
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2c4a305..51f30e2 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -157,7 +157,25 @@ int MXCreateCachedOp(SymbolHandle handle,
 
   API_BEGIN();
   *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym));
+      new Imperative::CachedOp(
+        *sym, std::vector<std::pair<std::string, std::string> >()));
+  API_END();
+}
+
+int MXCreateCachedOpEx(SymbolHandle handle,
+                       int num_params,
+                       const char** keys,
+                       const char** vals,
+                       CachedOpHandle *out) {
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
+
+  API_BEGIN();
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  for (int i = 0; i < num_params; ++i) {
+    kwargs.push_back({keys[i], vals[i]});
+  }
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(*sym, kwargs));
   API_END();
 }
 
@@ -198,16 +216,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
     }
   }
 
-  OpStatePtr state = op->Forward(ndinputs, ndoutputs);
-  if (Imperative::Get()->is_recording()) {
-    nnvm::NodeAttrs attrs;
-    attrs.op = cached_op;
-    attrs.name = "_cachedop";
-    attrs.parsed = op;
-    Imperative::Get()->RecordOp(
-        std::move(attrs), ndinputs, ndoutputs, state,
-        &op->save_inputs(), &op->save_outputs());
-  }
+  op->Forward(op, ndinputs, ndoutputs);
 
   if (*outputs == nullptr) {
     ret->ret_handles.clear();
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5717327..eaa95a5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,12 +22,18 @@
 
 namespace mxnet {
 
-Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
+DMLC_REGISTER_PARAMETER(CachedOpParam);
+
+Imperative::CachedOp::CachedOp(
+    const nnvm::Symbol& sym,
+    const std::vector<std::pair<std::string, std::string> >& kwargs) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
 
+  param_.Init(kwargs);
+
   // construct forward graph
   {
     NodeEntryMap<int> dedup_out;
@@ -59,6 +65,8 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
 
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
+
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
   }
 
   // construct backward graph
@@ -321,13 +329,16 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
   return g;
 }
 
-OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
-                                         const std::vector<NDArray*>& outputs) {
+void Imperative::CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  bool recording = Imperative::Get()->set_is_recording(false);
   // Initialize
+  bool recording = Imperative::Get()->is_recording();
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
@@ -381,10 +392,16 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
+  if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);
+
   Imperative::Get()->RunGraph(
       false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
       std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+  Imperative::Get()->set_is_recording(recording);
+
   for (size_t i = 0; i < idx.num_node_entries(); ++i) {
     if (arrays[i] == &buff[i]) continue;
     buff[i].shape_ = arrays[i]->shape_;
@@ -392,9 +409,15 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
     buff[i].storage_type_ = arrays[i]->storage_type_;
   }
 
-  Imperative::Get()->set_is_recording(recording);
-
-  return op_state_ptr;
+  if (recording && !inlining_) {
+    nnvm::NodeAttrs attrs;
+    attrs.op = cached_op;
+    attrs.name = "_cachedop";
+    attrs.parsed = op_ptr;
+    Imperative::Get()->RecordOp(
+        std::move(attrs), inputs, outputs, op_state_ptr,
+        &save_inputs(), &save_outputs());
+  }
 }
 
 
@@ -452,10 +475,14 @@ void Imperative::CachedOp::Backward(
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
+  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);
+
   Imperative::Get()->RunGraph(
       retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
       std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+
   if (retain_graph) {
     buff.resize(num_forward_entries);
   } else {
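
The inlining_ switch is observable from Python: when a cached op inlines, the autograd tape records the graph's individual operators; otherwise it records a single _CachedOp node. A sketch of the check used by test_inline at the bottom of this patch (net as hybridized in the first sketch above):

    import json
    import mxnet as mx

    with mx.autograd.record():
        y = net(mx.nd.zeros((1, 10)))
    nodes = json.loads(mx.autograd.get_symbol(y).tojson())['nodes']
    # the node count differs between inline_limit=3 and inline_limit=0 runs
    print(len(nodes))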
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 361b971..fbbaf82 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -288,8 +288,6 @@ void Imperative::RunGraph(
   DTypeVector arg_dtypes;
   std::vector<OpReqType> req;
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(10);
-
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
     if (node.source->op() == nullptr) continue;
@@ -353,8 +351,6 @@ void Imperative::RunGraph(
       if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
     }
   }
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
 
@@ -367,6 +363,7 @@ std::vector<NDArray*> Imperative::Backward(
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
+  static const Op* copy_op = Op::Get("_copy");
 
   // Construct forward graph
   Graph graph;
@@ -439,7 +436,14 @@ std::vector<NDArray*> Imperative::Backward(
       zero_ops, "_copy");
   CHECK_EQ(g_graph.outputs.size(), xs.size());
   for (const auto &e : g_graph.outputs) {
-    graph.outputs.push_back(e);
+    if (e.node->op() == nullptr) {
+      auto node = Node::Create();
+      node->attrs.op = copy_op;
+      node->inputs.push_back(e);
+      graph.outputs.push_back(NodeEntry{node, 0, 0});
+    } else {
+      graph.outputs.push_back(e);
+    }
   }
   const auto& idx = graph.indexed_graph();
   // get number of nodes used in forward pass
@@ -575,10 +579,12 @@ std::vector<NDArray*> Imperative::Backward(
 
   bool prev_recording = set_is_recording(create_graph);
   bool prev_training = set_is_training(is_train);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
            std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
   set_is_recording(prev_recording);
   set_is_training(prev_training);
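
The _copy insertion above is the second half of the commit title: when a gradient-graph output is a bare variable node (op() == nullptr), RunGraph skips it, so the corresponding gradient was never written. A plausible repro at the Python level (illustrative; the patch's regression test for this case is not shown here):

    import mxnet as mx

    x = mx.nd.ones((3,))
    x.attach_grad()
    # The gradient of a variable w.r.t. itself is its head gradient, which
    # appears in the gradient graph as a bare variable node; it is now
    # routed through an explicit _copy so the write actually happens.
    x.backward()
    print(x.grad)  # expected: all ones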
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df9f78e..c619056 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,6 +23,7 @@ import numpy as np
 from nose.tools import raises
 from copy import deepcopy
 import warnings
+import json
 
 
 def test_parameter():
@@ -256,7 +257,6 @@ def test_deconv():
     # # check_layer_forward(layer, (1, 10, 10, 10, 4))
 
 
-
 def test_pool():
     layers1d = [
         nn.MaxPool1D(),
@@ -611,6 +611,31 @@ def test_fill_shape_load():
     assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
 
 
+def test_inline():
+    net = mx.gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+
+    net.initialize()
+    net.hybridize(inline_limit=3)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_1 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    net.hybridize(inline_limit=0)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_2 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    assert len_1 == len_2 + 2
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
