piiswrong closed pull request #8701: add inline for cached op & fixed a bug when calling backward on variable
URL: https://github.com/apache/incubator-mxnet/pull/8701
 
This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:
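
In user-facing terms, the change lets gluon's hybridize() pass keyword flags
down to the underlying CachedOp. A minimal usage sketch against a build that
includes this patch (the flag names are the CachedOpParam fields introduced
below; the numeric values are illustrative):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(10))
        net.add(nn.Dense(10))
    net.initialize()

    # New in this PR: hybridize() accepts flags and threads them through
    # HybridBlock._flags into ndarray.CachedOp(out, self._flags).
    net.hybridize(inline_limit=2, forward_bulk_size=15, backward_bulk_size=15)

    with mx.autograd.record():
        y = net(mx.nd.zeros((1, 10)))
    y.backward()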

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9815786ba7..faa453529e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -793,8 +793,15 @@ MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
 /*!
  * \brief create cached operator
  */
-MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
-                               CachedOpHandle *out);
+MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
+/*!
+ * \brief create cached operator
+ */
+MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
+                                 int num_params,
+                                 const char** keys,
+                                 const char** vals,
+                                 CachedOpHandle *out);
 /*!
  * \brief free cached operator
  */
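
For reference, a sketch of driving the new C entry point directly from
ctypes, mirroring what python/mxnet/_ctypes/ndarray.py does further down in
this diff. This assumes the usual helpers from mxnet.base (_LIB, check_call,
c_str_array, CachedOpHandle); flags are passed as parallel key/value C-string
arrays:

    import ctypes
    import mxnet as mx
    from mxnet.base import _LIB, check_call, c_str_array, CachedOpHandle

    sym = mx.sym.exp(mx.sym.var('data'))
    handle = CachedOpHandle()
    keys, vals = ['inline_limit'], ['2']
    check_call(_LIB.MXCreateCachedOpEx(
        sym.handle, len(keys), c_str_array(keys), c_str_array(vals),
        ctypes.byref(handle)))
    check_call(_LIB.MXFreeCachedOp(handle))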
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 88a9f4d597..d605e9d47c 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -28,11 +28,30 @@
 #include <nnvm/graph.h>
 #include <vector>
 #include <atomic>
+#include <utility>
+#include <string>
 #include <unordered_map>
 
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpParam) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,7 +96,8 @@ class Imperative {
   };
   class CachedOp {
    public:
-    explicit CachedOp(const nnvm::Symbol& sym);
+    CachedOp(const nnvm::Symbol& sym,
+             const std::vector<std::pair<std::string, std::string> >& kwargs);
     uint32_t num_inputs() {
       return fwd_graph_.indexed_graph().input_nodes().size();
     }
@@ -103,8 +123,9 @@ class Imperative {
                                  const std::vector<NDArray*>& inputs);
     std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
                                          const std::vector<nnvm::NodeEntry>& ograds);
-    OpStatePtr Forward(const std::vector<NDArray*>& inputs,
-                       const std::vector<NDArray*>& outputs);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& inputs,
+                 const std::vector<NDArray*>& outputs);
     void Backward(const bool retain_graph,
                   const OpStatePtr& state,
                   const std::vector<NDArray*>& inputs,
@@ -117,9 +138,11 @@ class Imperative {
       std::vector<OpStatePtr> states;
     };
     std::mutex mutex_;
+    CachedOpParam param_;
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    bool inlining_;
     std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
@@ -182,7 +205,11 @@ class Imperative {
  private:
   friend class NDArray;
   /*! \brief make constructor protected. */
-  Imperative() {}
+  Imperative() {
+    if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
+      backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+    }
+  }
   /*! \brief find the input/output ndarrays that are needed for backward */
   void GetBackwardDependency(
       const nnvm::NodePtr& node,
@@ -210,6 +237,8 @@ class Imperative {
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
   std::atomic<uint64_t> variable_count_{0};
+  /*! \brief default backward bulk size */
+  int backward_bulk_size_{0};
 };
 
 using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
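
The inline_limit flag gates the inlining_ member above: a cached graph whose
operator count (nodes minus inputs) is at most inline_limit is replayed
op-by-op while autograd is recording, instead of being recorded as a single
opaque _CachedOp node. A sketch of the two regimes, reusing the net from the
opening sketch (values illustrative):

    # At most 2 operators: each op is recorded individually, so small
    # graphs keep a transparent autograd history.
    net.hybridize(inline_limit=2)

    # inline_limit=0 disables inlining: every invocation is recorded as
    # one _CachedOp node regardless of graph size.
    net.hybridize(inline_limit=0)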
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index a0c01a6e06..20ad2bfbc5 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,10 +105,13 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym):
+    def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
-        check_call(_LIB.MXCreateCachedOp(
+        check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
+            len(flags),
+            c_str_array([key for key, _ in flags]),
+            c_str_array([str(val) for _, val in flags]),
             ctypes.byref(self.handle)))
 
     def __del__(self):
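
The frontend class can also be used without gluon. A small sketch of
constructing a CachedOp from a symbol with flags, assuming CachedOp is
re-exported under mx.nd the same way gluon's block.py reaches it (the
constructor stringifies flag values, so plain ints work):

    import mxnet as mx

    a, b = mx.sym.var('a'), mx.sym.var('b')
    c = 2 * a + b
    # flags is any iterable of (key, value) pairs.
    op = mx.nd.CachedOp(c, flags=[('inline_limit', 0)])
    out = op(mx.nd.ones((2, 2)), mx.nd.ones((2, 2)))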
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 466f87fade..37734ac389 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -274,7 +274,7 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
         """
         self.collect_params().initialize(init, ctx, verbose)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has 
no effect on
         non-hybrid children.
 
@@ -282,9 +282,11 @@ def hybridize(self, active=True):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
         for cld in self._children:
-            cld.hybridize(active)
+            cld.hybridize(active, **kwargs)
 
     def cast(self, dtype):
         """Cast this Block to use another data type.
@@ -343,6 +345,7 @@ def __init__(self, prefix=None, params=None):
         self._out_format = None
         self._in_format = None
         self._active = False
+        self._flags = {}
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -378,7 +381,7 @@ def _get_graph(self, *args):
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
         input_idx = {var.name: i for i, var in enumerate(inputs)}
-        self._cached_op = ndarray.CachedOp(out)
+        self._cached_op = ndarray.CachedOp(out, self._flags)
         params = dict(self.collect_params().items())
 
         # verify graph inputs
@@ -437,9 +440,11 @@ def register_child(self, block):
         super(HybridBlock, self).register_child(block)
         self._clear_cached_op()
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         self._active = active
-        super(HybridBlock, self).hybridize(active)
+        self._flags = kwargs.items()
+        self._clear_cached_op()
+        super(HybridBlock, self).hybridize(active, **kwargs)
 
     def cast(self, dtype):
         self._clear_cached_op()
@@ -615,5 +620,10 @@ def forward(self, x, *args):
        ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
         return _regroup(list(ret), self._out_format)[0]
 
+    def _clear_cached_op(self):
+        tmp = self._cached_graph
+        super(SymbolBlock, self)._clear_cached_op()
+        self._cached_graph = tmp
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
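
Two behavioral details in this hunk are worth calling out: re-hybridizing now
clears the cached op so that new flags take effect on the next forward, and
SymbolBlock overrides _clear_cached_op to keep its fixed _cached_graph, which
the base-class reset would otherwise wipe. A sketch, assuming the usual
SymbolBlock construction:

    import mxnet as mx
    from mxnet import gluon

    data = mx.sym.var('data')
    out = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
    net = gluon.SymbolBlock(out, data)
    net.collect_params().initialize()

    net.hybridize(inline_limit=0)   # rebuilds the CachedOp with new flags
    y = net(mx.nd.ones((1, 4)))     # _cached_graph survives the rebuild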
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index c0b4b52382..ab5d5e167f 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -68,7 +68,7 @@ def __getitem__(self, key):
     def __len__(self):
         return len(self._children)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates `HybridBlock`s recursively. Has no effect 
on
         non-hybrid children.
 
@@ -76,11 +76,13 @@ def hybridize(self, active=True):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
        if self._children and all(isinstance(c, HybridBlock) for c in self._children):
            warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \
                           'using HybridSequential for the best performance.')
-        super(Sequential, self).hybridize(active)
+        super(Sequential, self).hybridize(active, **kwargs)
 
 
 class HybridSequential(HybridBlock):
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2c4a305011..51f30e2231 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -157,7 +157,25 @@ int MXCreateCachedOp(SymbolHandle handle,
 
   API_BEGIN();
   *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym));
+      new Imperative::CachedOp(
+        *sym, std::vector<std::pair<std::string, std::string> >()));
+  API_END();
+}
+
+int MXCreateCachedOpEx(SymbolHandle handle,
+                       int num_params,
+                       const char** keys,
+                       const char** vals,
+                       CachedOpHandle *out) {
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
+
+  API_BEGIN();
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  for (int i = 0; i < num_params; ++i) {
+    kwargs.push_back({keys[i], vals[i]});
+  }
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(*sym, kwargs));
   API_END();
 }
 
@@ -198,16 +216,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
     }
   }
 
-  OpStatePtr state = op->Forward(ndinputs, ndoutputs);
-  if (Imperative::Get()->is_recording()) {
-    nnvm::NodeAttrs attrs;
-    attrs.op = cached_op;
-    attrs.name = "_cachedop";
-    attrs.parsed = op;
-    Imperative::Get()->RecordOp(
-        std::move(attrs), ndinputs, ndoutputs, state,
-        &op->save_inputs(), &op->save_outputs());
-  }
+  op->Forward(op, ndinputs, ndoutputs);
 
   if (*outputs == nullptr) {
     ret->ret_handles.clear();
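
Note that MXInvokeCachedOp no longer records the op itself; recording moved
into CachedOp::Forward (see cached_op.cc below), which now receives a
shared_ptr to the op so the recorded node keeps it alive. From the frontend
this is invisible; reusing the CachedOp op from the sketch above:

    x = mx.nd.ones((2, 2))
    x.attach_grad()
    with mx.autograd.record():
        out = op(x, mx.nd.ones((2, 2)))  # recorded inside Forward now
    out.backward()
    print(x.grad)  # d(2a + b)/da = 2 everywhere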
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5717327f87..eaa95a5f24 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,12 +22,18 @@
 
 namespace mxnet {
 
-Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
+DMLC_REGISTER_PARAMETER(CachedOpParam);
+
+Imperative::CachedOp::CachedOp(
+    const nnvm::Symbol& sym,
+    const std::vector<std::pair<std::string, std::string> >& kwargs) {
   using namespace nnvm;
   using namespace imperative;
  static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
 
+  param_.Init(kwargs);
+
   // construct forward graph
   {
     NodeEntryMap<int> dedup_out;
@@ -59,6 +65,8 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
 
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
+
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
   }
 
   // construct backward graph
@@ -321,13 +329,16 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
   return g;
 }
 
-OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
-                                         const std::vector<NDArray*>& outputs) {
+void Imperative::CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  bool recording = Imperative::Get()->set_is_recording(false);
   // Initialize
+  bool recording = Imperative::Get()->is_recording();
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
@@ -381,10 +392,16 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
+  if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);
+
   Imperative::Get()->RunGraph(
       false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
       std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+  Imperative::Get()->set_is_recording(recording);
+
   for (size_t i = 0; i < idx.num_node_entries(); ++i) {
     if (arrays[i] == &buff[i]) continue;
     buff[i].shape_ = arrays[i]->shape_;
@@ -392,9 +409,15 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
     buff[i].storage_type_ = arrays[i]->storage_type_;
   }
 
-  Imperative::Get()->set_is_recording(recording);
-
-  return op_state_ptr;
+  if (recording && !inlining_) {
+    nnvm::NodeAttrs attrs;
+    attrs.op = cached_op;
+    attrs.name = "_cachedop";
+    attrs.parsed = op_ptr;
+    Imperative::Get()->RecordOp(
+        std::move(attrs), inputs, outputs, op_state_ptr,
+        &save_inputs(), &save_outputs());
+  }
 }
 
 
@@ -452,10 +475,14 @@ void Imperative::CachedOp::Backward(
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
+  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);
+
   Imperative::Get()->RunGraph(
       retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
       std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+
   if (retain_graph) {
     buff.resize(num_forward_entries);
   } else {
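
The bulk-size flags wrap RunGraph in set_bulk_size() calls, replacing the
bulk size that RunGraph previously hard-coded (the hunk removed from
imperative.cc below). Defaults come from the engine's existing environment
variables, read once at startup, leaving two knobs; a sketch reusing the net
from the opening example (values illustrative):

    # Process-wide default, set in the environment before launch:
    #   MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN=30 python train.py
    # Per-block override, new in this PR:
    net.hybridize(forward_bulk_size=30, backward_bulk_size=30)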
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 361b971a2d..fbbaf82d17 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -288,8 +288,6 @@ void Imperative::RunGraph(
   DTypeVector arg_dtypes;
   std::vector<OpReqType> req;
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(10);
-
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
     if (node.source->op() == nullptr) continue;
@@ -353,8 +351,6 @@ void Imperative::RunGraph(
       if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
     }
   }
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
 
@@ -367,6 +363,7 @@ std::vector<NDArray*> Imperative::Backward(
   using namespace nnvm;
   using namespace imperative;
  static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
+  static const Op* copy_op = Op::Get("_copy");
 
   // Construct forward graph
   Graph graph;
@@ -439,7 +436,14 @@ std::vector<NDArray*> Imperative::Backward(
       zero_ops, "_copy");
   CHECK_EQ(g_graph.outputs.size(), xs.size());
   for (const auto &e : g_graph.outputs) {
-    graph.outputs.push_back(e);
+    if (e.node->op() == nullptr) {
+      auto node = Node::Create();
+      node->attrs.op = copy_op;
+      node->inputs.push_back(e);
+      graph.outputs.push_back(NodeEntry{node, 0, 0});
+    } else {
+      graph.outputs.push_back(e);
+    }
   }
   const auto& idx = graph.indexed_graph();
   // get number of nodes used in forward pass
@@ -575,10 +579,12 @@ std::vector<NDArray*> Imperative::Backward(
 
   bool prev_recording = set_is_recording(create_graph);
   bool prev_training = set_is_training(is_train);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
   set_is_recording(prev_recording);
   set_is_training(prev_training);
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df9f78e0ce..c619056c11 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,6 +23,7 @@
 from nose.tools import raises
 from copy import deepcopy
 import warnings
+import json
 
 
 def test_parameter():
@@ -256,7 +257,6 @@ def test_deconv():
     # # check_layer_forward(layer, (1, 10, 10, 10, 4))
 
 
-
 def test_pool():
     layers1d = [
         nn.MaxPool1D(),
@@ -611,6 +611,31 @@ def test_fill_shape_load():
     assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
 
 
+def test_inline():
+    net = mx.gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+
+    net.initialize()
+    net.hybridize(inline_limit=3)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_1 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    net.hybridize(inline_limit=0)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_2 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    assert len_1 == len_2 + 2
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
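
A note on the test's final assertion: with inline_limit=3 the three Dense
layers (three FullyConnected operators) are inlined into the recorded graph,
giving data + 3 x (weight, bias, FC) = 10 symbol nodes; with inline_limit=0
the same computation is recorded as a single _CachedOp node over 7 variable
inputs, giving 8 nodes. Collapsing three operator nodes into one removes
exactly two, hence len_1 == len_2 + 2.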


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
