This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a7ea976  Static alloc for hybridblock (#11320)
a7ea976 is described below

commit a7ea97651edcb87668541efe36b35a87b7b2b50e
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Wed Jun 20 14:08:24 2018 -0700

    Static alloc for hybridblock (#11320)
    
    * Revert "Revert "Static alloc for hybridblock (#11313)" (#11318)"
    
    This reverts commit 9b8eb567273d9521c639ea783e1117fe09cb3d64.
    
    * fix
---
 include/mxnet/c_api.h                   |   5 -
 include/mxnet/imperative.h              |  89 ----
 include/mxnet/ndarray.h                 |   8 +
 include/mxnet/op_attr_types.h           |  33 +-
 python/mxnet/_ctypes/ndarray.py         |  16 +-
 python/mxnet/gluon/block.py             |  70 +--
 src/c_api/c_api_ndarray.cc              |  26 +-
 src/engine/threaded_engine.cc           |   3 +-
 src/executor/attach_op_execs_pass.cc    | 165 +++----
 src/executor/attach_op_resource_pass.cc |  16 +-
 src/executor/exec_pass.h                |  28 +-
 src/executor/graph_executor.cc          |   2 +-
 src/imperative/cached_op.cc             | 754 ++++++++++++++++++++++++++------
 src/imperative/cached_op.h              | 174 ++++++++
 src/imperative/imperative.cc            |  90 +---
 src/imperative/imperative_utils.cc      | 120 +++++
 src/imperative/imperative_utils.h       | 256 +++++++++--
 tests/python/unittest/test_gluon.py     |  67 ++-
 18 files changed, 1399 insertions(+), 523 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 55c26bc..4dd858a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
                                  int num_flags,
                                  const char** keys,
                                  const char** vals,
-                                 int num_inputs,
-                                 const char** input_names,
-                                 int num_params,
-                                 const char** param_names,
-                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 758ce85..7ea60df 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,23 +35,6 @@
 #include "./ndarray.h"
 
 namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
-  uint32_t inline_limit;
-  uint32_t forward_bulk_size;
-  uint32_t backward_bulk_size;
-  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
-    DMLC_DECLARE_FIELD(inline_limit)
-    .set_default(2)
-    .describe("Maximum number of operators that can be inlined.");
-    DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during forward pass.");
-    DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during backward pass.");
-  }
-};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -94,67 +77,6 @@ class Imperative {
              && info.out_grads.size() == 1;
     }
   };
-  class CachedOp {
-   public:
-    CachedOp(
-        const nnvm::Symbol& sym,
-        const std::vector<std::pair<std::string, std::string> >& flags,
-        const std::vector<std::string> arg_names,
-        const std::unordered_map<std::string, std::vector<NDArray> >& params);
-    uint32_t num_inputs() {
-      return fwd_graph_.indexed_graph().input_nodes().size();
-    }
-    uint32_t num_outputs() {
-      return fwd_graph_.outputs.size();
-    }
-    uint32_t num_backward_inputs() {
-      return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
-    }
-    std::vector<bool>& save_inputs() {
-      return save_inputs_;
-    }
-    std::vector<bool>& save_outputs() {
-      return save_outputs_;
-    }
-    const std::unordered_set<uint32_t>& mutable_input_nodes() {
-      return fwd_graph_.indexed_graph().mutable_input_nodes();
-    }
-    nnvm::Graph GetForwardGraph(const bool recording,
-                                const std::vector<NDArray*>& inputs);
-    nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
-                                 const std::vector<OpReqType>& reqs,
-                                 const std::vector<NDArray*>& inputs);
-    std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
-                                          const std::vector<nnvm::NodeEntry>& 
ograds);
-    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
-                 const std::vector<NDArray*>& args,
-                 const std::vector<NDArray*>& outputs);
-    void Backward(const bool retain_graph,
-                  const OpStatePtr& state,
-                  const std::vector<NDArray*>& inputs,
-                  const std::vector<OpReqType>& reqs,
-                  const std::vector<NDArray*>& outputs);
-
-   private:
-    struct CachedOpState {
-      std::vector<NDArray> buff;
-      std::vector<OpStatePtr> states;
-    };
-    std::mutex mutex_;
-    CachedOpConfig config_;
-    nnvm::Graph fwd_graph_;
-    nnvm::Graph grad_graph_;
-    nnvm::Graph full_graph_;
-    std::unordered_map<Context, std::vector<NDArray> > params_;
-    bool inlining_;
-    std::vector<nnvm::NodeEntry> ograd_entries_;
-    std::vector<bool> curr_grad_req_;
-    std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
-    std::vector<uint32_t> fwd_args_idx_;
-    std::vector<uint32_t> fwd_params_idx_;
-    std::vector<uint32_t> bwd_input_eid_;
-    std::vector<bool> save_inputs_, save_outputs_;
-  };
   /*! \brief whether operator recording is on. */
   bool is_training() const {
     return is_train_;
@@ -222,15 +144,6 @@ class Imperative {
       uint32_t num_inputs, uint32_t num_outputs,
       std::vector<bool> *p_save_inputs,
       std::vector<bool> *p_save_outputs);
-  void RunGraph(
-      const bool retain_graph,
-      const nnvm::IndexedGraph& idx,
-      const std::vector<NDArray*> arrays,
-      size_t node_start, size_t node_end,
-      std::vector<OpReqType>&& array_reqs,
-      std::vector<uint32_t>&& ref_count,
-      std::vector<OpStatePtr> *p_states,
-      const DispatchModeVector& dispatch_modes);
   /*! \brief indicate whether is training. */
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local bool is_train_;
@@ -247,7 +160,5 @@ class Imperative {
   int backward_bulk_size_{0};
 };
 
-using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
-
 }  // namespace mxnet
 #endif  // MXNET_IMPERATIVE_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e243eb7..ae96fd8 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,6 +155,14 @@ class NDArray {
     return byte_offset_ > 0 || shape() != ptr_->storage_shape;
   }
 
+  /* \brief Check whether the two arrays are the same array */
+  inline bool IsSame(const NDArray& other) {
+    return ptr_ == other.ptr_ &&
+        shape_ == other.shape_ &&
+        byte_offset_ == other.byte_offset_ &&
+        dtype_ == other.dtype_;
+  }
+
   /*!
    * \return the shape of current NDArray.
    */
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 3969d84..f4694ef 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,25 +126,36 @@ class OpStatePtr {
   template<typename T, typename... Args>
   static OpStatePtr Create(Args&&... args) {
     OpStatePtr ret;
-    ret.ptr_ = std::make_shared<OpState>();
-    ret.ptr_->var_ = Engine::Get()->NewVariable();
-    ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
+    auto state = new T(std::forward<Args>(args)...);
+    auto var = Engine::Get()->NewVariable();
+    ret.ptr_.reset(
+      new OpState(var, state),
+      [](OpState* p) {
+        Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), 
p->var);
+        delete reinterpret_cast<T*>(p->state);
+        delete p;
+      });
 
     return ret;
   }
   /* \brief Get engine variable associated with this state */
   engine::VarHandle get_var() const {
-    return ptr_->var_;
+    return ptr_->var;
   }
   /* \brief Get state of type T */
   template<typename T>
   T& get_state() const {
-    return dmlc::get<T>(ptr_->state_);
+    return *reinterpret_cast<T*>(ptr_->state);
   }
   /* \brief clear state */
   void reset() {
     ptr_.reset();
   }
+  /* \brief checks whether the managed object is managed only by the current
+            OpStatePtr instance */
+  bool unique() const {
+    return ptr_.unique();
+  }
   /* \brief Whether state is empty */
   explicit operator bool() const {
     return ptr_ ? true : false;
@@ -153,16 +164,12 @@ class OpStatePtr {
  private:
   /* \brief state structure */
   struct OpState {
-    OpState() {}
+    engine::VarHandle var;
+    void* state;
+
+    OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
     OpState(const OpState& other) = delete;
     OpState& operator=(const OpState& other) = delete;
-
-    ~OpState() {
-      Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
-    }
-
-    engine::VarHandle var_;
-    dmlc::any state_;
   };
   /* \brief shared pointer to state */
   std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index d2cae0c..f324545 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=(), inputs=None, params=None):
+    def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
-        param_names = []
-        param_arrays = []
-        if inputs is None:
-            assert params is None, "When inputs is None params must also be 
None."
-            inputs = sym.list_inputs()
-        elif params is not None:
-            for name, arrs in params.items():
-                param_arrays.extend(arrs)
-                param_names.extend([name] * len(arrs))
 
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
-            len(inputs),
-            c_str_array(inputs),
-            len(param_names),
-            c_str_array(param_names),
-            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 3b97c05..0845669 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,8 +502,12 @@ class Block(object):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
-        **kwargs : string
-            Additional flags for hybridized operator.
+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may 
increase.
+        static_shape : bool, default False
+            Optimize for invariant input shapes between iterations. Must also
+            set static_alloc to True. Change of input shapes is still allowed
+            but slower.
         """
         for cld in self._children.values():
             cld.hybridize(active, **kwargs)
@@ -696,7 +700,7 @@ class HybridBlock(Block):
         self._out_format = None
         self._in_format = None
         self._active = False
-        self._flags = {}
+        self._flags = []
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -723,39 +727,43 @@ class HybridBlock(Block):
         return self._cached_graph
 
     def _build_cache(self, *args):
-        inputs, out = self._get_graph(*args)
-        input_names = [i.name for i in inputs]
-
+        data, out = self._get_graph(*args)
+        data_names = {data.name : i for i, data in enumerate(data)}
         params = self.collect_params()
+        input_names = out.list_inputs()
+
         param_names = set(params.keys())
-        expected_names = set(out.list_inputs())
+        expected_names = set(input_names)
         for name in expected_names:
-            assert name in param_names or name in input_names, \
+            assert name in param_names or name in data_names, \
                 "Unknown input to HybridBlock: %s"%name
 
-        used_input_names = [i for i in input_names if i in expected_names]
-        if len(used_input_names) != len(input_names):
-            unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
+        used_data_names = [i for i in data_names if i in expected_names]
+        if len(used_data_names) != len(data_names):
+            unused = ', '.join(['%d-th'%i for name, i in data_names.items()
                                 if name not in expected_names])
             warnings.warn("The %s input to HybridBlock is not used by any "
                           "computation. Is this intended?"%unused, 
stacklevel=4)
 
-        used_param_names = set(i for i in param_names if i in expected_names)
+        used_param_names = [i for i in param_names if i in expected_names]
         if len(used_param_names) != len(param_names):
-            unused = ', '.join(list(param_names - used_param_names))
+            unused = ', '.join(list(param_names - set(used_param_names)))
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        used_params = {k: params[k] for k in used_param_names}
-        try:
-            param_dict = {k: v.list_data() for k, v in used_params.items()}
-        except DeferredInitializationError:
-            self._deferred_infer_shape(*args)
-            for i in used_params.values():
-                i._finish_deferred_init()
-            param_dict = {k: v.list_data() for k, v in used_params.items()}
-
-        self._cached_op = ndarray.CachedOp(out, self._flags, input_names, 
param_dict)
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', 
param_indices)] + \
+                self._flags
+        self._cached_op = ndarray.CachedOp(out, flags)
 
     def _deferred_infer_shape(self, *args):
         try:
@@ -771,7 +779,19 @@ class HybridBlock(Block):
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
-        out = self._cached_op(*args)
+        try:
+            cargs = [args[i] if is_arg else i.data()
+                     for is_arg, i in self._cached_op_args]
+        except DeferredInitializationError:
+            self._deferred_infer_shape(*args)
+            cargs = []
+            for is_arg, i in self._cached_op_args:
+                if is_arg:
+                    cargs.append(args[i])
+                else:
+                    i._finish_deferred_init()
+                    cargs.append(i.data())
+        out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)[0]
@@ -792,7 +812,7 @@ class HybridBlock(Block):
 
     def hybridize(self, active=True, **kwargs):
         self._active = active
-        self._flags = kwargs.items()
+        self._flags = list(kwargs.items())
         self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{}" is being hybridized while still having forward 
hook/pre-hook. '
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 9aabe04..34bd4b2 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,6 +36,7 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../imperative/imperative_utils.h"
+#include "../imperative/cached_op.h"
 
 using namespace mxnet;
 
@@ -160,12 +161,8 @@ int MXCreateCachedOp(SymbolHandle handle,
   std::vector<std::string> input_names;
   input_names.reserve(inputs.size());
   for (const auto& i : inputs) input_names.push_back(i->attrs.name);
-  *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(
-        *sym,
-        std::vector<std::pair<std::string, std::string> >(),
-        input_names,
-        std::unordered_map<std::string, std::vector<NDArray> >()));
+  *out = new CachedOpPtr(new CachedOp(
+      *sym, std::vector<std::pair<std::string, std::string> >()));
   API_END();
 }
 
@@ -173,11 +170,6 @@ int MXCreateCachedOpEx(SymbolHandle handle,
                        int num_flags,
                        const char** keys,
                        const char** vals,
-                       int num_args,
-                       const char** arg_names,
-                       int num_params,
-                       const char** param_names,
-                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
@@ -186,17 +178,7 @@ int MXCreateCachedOpEx(SymbolHandle handle,
   for (int i = 0; i < num_flags; ++i) {
     flags.push_back({keys[i], vals[i]});
   }
-  std::vector<std::string> args;
-  for (int i = 0; i < num_args; ++i) {
-    args.push_back(arg_names[i]);
-  }
-  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
-  for (int i = 0; i < num_params; ++i) {
-    param_dict[param_names[i]].emplace_back(
-        *reinterpret_cast<NDArray*>(params[i]));
-  }
-  *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym, flags, args, param_dict));
+  *out = new CachedOpPtr(new CachedOp(*sym, flags));
   API_END();
 }
 
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index dc0436e..e70cc19 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,6 +278,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
 }
 
 void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool 
profiling) {
+  BulkFlush();
+
   ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
   OprBlock* opr_block = OprBlock::New();
   opr_block->opr = threaded_opr;
@@ -323,7 +325,6 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
         << device_count_;
   }
 #endif
-  BulkFlush();
   ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, 
prop, opr_name, wait);
   opr->temporary = true;
   const bool profiling = 
profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc 
b/src/executor/attach_op_execs_pass.cc
index 697e486..72919d9 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,6 +134,10 @@ class StatefulComputeExecutor : public 
StorageFallbackOpExecutor {
     return state_.get_var();
   }
 
+  OpStatePtr state() const override {
+    return state_;
+  }
+
   explicit StatefulComputeExecutor(const OpStatePtr& state,
                                    const FStatefulCompute& fcompute,
                                    ExecType exec_type,
@@ -142,7 +146,6 @@ class StatefulComputeExecutor : public 
StorageFallbackOpExecutor {
         state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
-  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulCompute fcompute_;
   ExecType exec_type_;
@@ -170,13 +173,16 @@ class StatefulComputeExExecutor : public OpExecutor {
     return state_.get_var();
   }
 
+  OpStatePtr state() const override {
+    return state_;
+  }
+
   explicit StatefulComputeExExecutor(const OpStatePtr& state,
                                      const FStatefulComputeEx& fcompute,
                                      ExecType exec_type)
       : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
-  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulComputeEx fcompute_;
   ExecType exec_type_;
@@ -241,16 +247,15 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
   using nnvm::DTypeVector;
   using nnvm::ShapeVector;
   using nnvm::FMutateInputs;
 
-  auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
-  auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
-  auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+  static auto& fcreate_op_state = 
nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& fmutate_inputs = 
nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+  static auto& is_layer_backward = 
nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
 
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -259,81 +264,87 @@ Graph AttachOpExecs(Graph g) {
 
   // get the graph
   const auto& idx = g.indexed_graph();
-  std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
+  OpExecVector& ret = *p_ret;
 
   // initialize the nodes
-  for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    const auto& inode = idx[i];
-    if (inode.source->is_variable()) continue;
-    const nnvm::Op *op = inode.source->op();
-    ExecType exec_type = ExecType::kSync;
-    std::vector<uint32_t> mutate_index;
-    if (fmutate_inputs.count(op)) {
-      mutate_index = fmutate_inputs[op](inode.source->attrs);
-    }
-    if (fexec_type.count(op)) {
-      exec_type = fexec_type[op](inode.source->attrs);
+  const auto& inode = idx[i];
+  if (inode.source->is_variable()) return;
+  const nnvm::Op *op = inode.source->op();
+  ExecType exec_type = ExecType::kSync;
+  std::vector<uint32_t> mutate_index;
+  if (fmutate_inputs.count(op)) {
+    mutate_index = fmutate_inputs[op](inode.source->attrs);
+  }
+  if (fexec_type.count(op)) {
+    exec_type = fexec_type[op](inode.source->attrs);
+  }
+  CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+  if (fcreate_op_state.count(op)) {
+    std::vector<TShape> ishape;
+    std::vector<int> itype;
+    for (const auto& e : inode.inputs) {
+      ishape.emplace_back(vshape[idx.entry_id(e)]);
+      itype.emplace_back(vdtype[idx.entry_id(e)]);
     }
-    CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
-    if (fcreate_op_state.count(op)) {
-      std::vector<TShape> ishape;
-      std::vector<int> itype;
-      for (const auto& e : inode.inputs) {
-        ishape.emplace_back(vshape[idx.entry_id(e)]);
-        itype.emplace_back(vdtype[idx.entry_id(e)]);
-      }
 
-      OpStatePtr state = fcreate_op_state[op](
-          inode.source->attrs, vctx[i], ishape, itype);
-      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-          op, "FStatefulComputeEx", vctx[i]);
-      // FStatefulComputeEx is dispatched only when dispatch_mode is 
DispatchMode::kFComputeEx
-      if (fcompute_ex != nullptr && dispatch_modes[i] == 
DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<StatefulComputeExExecutor>(state, 
fcompute_ex, exec_type);
-      } else {
-        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-            op, "FStatefulCompute", vctx[i]);
-        CHECK(fcompute != nullptr)
-            << "One of FStatefulCompute and FStatefulComputeEx must be 
registered "
-            << "for stateful operator " << op->name;
-        ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
-                                                           exec_type, 
mutate_index);
-      }
-    } else if (is_layer_backward.get(op, false)) {
-      CHECK_GE(inode.control_deps.size(), 1);
-      uint32_t fwd_id = inode.control_deps[0];
-      CHECK(vctx[fwd_id] == vctx[i]);
-      CHECK(ret[fwd_id] != nullptr);
-      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-          op, "FStatefulComputeEx", vctx[i]);
-      // FStatefulComputeEx is dispatched only when dispatch_mode is 
DispatchMode::kFComputeEx
-      if (fcompute_ex != nullptr && dispatch_modes[i] == 
DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<StatefulComputeExExecutor>(
-            
dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
-            fcompute_ex, exec_type);
-      } else {
-        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-            op, "FStatefulCompute", vctx[i]);
-        CHECK(fcompute != nullptr)
-            << "One of FStatefulCompute and FStatefulComputeEx must be 
registered "
-            << "for stateful operator " << op->name;
-        ret[i] = std::make_shared<StatefulComputeExecutor>(
-            dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
-            fcompute, exec_type, mutate_index);
-      }
+    OpStatePtr state = fcreate_op_state[op](
+        inode.source->attrs, vctx[i], ishape, itype);
+    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+        op, "FStatefulComputeEx", vctx[i]);
+    // FStatefulComputeEx is dispatched only when dispatch_mode is 
DispatchMode::kFComputeEx
+    if (fcompute_ex != nullptr && dispatch_modes[i] == 
DispatchMode::kFComputeEx) {
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, 
exec_type);
     } else {
-      FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", 
vctx[i]);
-      FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", 
vctx[i]);
-      if (fcomp_ex != nullptr && dispatch_modes[i] == 
DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<FComputeExExecutor>(
-            inode.source->attrs, fcomp_ex, exec_type);
-      } else if (fcompute != nullptr) {
-        ret[i] = std::make_shared<FComputeExecutor>(
-            inode.source->attrs, fcompute, exec_type, mutate_index);
-      } else {
-        LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
-      }
+      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+          op, "FStatefulCompute", vctx[i]);
+      CHECK(fcompute != nullptr)
+          << "One of FStatefulCompute and FStatefulComputeEx must be 
registered "
+          << "for stateful operator " << op->name;
+      ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+                                                         exec_type, 
mutate_index);
+    }
+  } else if (is_layer_backward.get(op, false)) {
+    CHECK_GE(inode.control_deps.size(), 1);
+    uint32_t fwd_id = inode.control_deps[0];
+    CHECK(vctx[fwd_id] == vctx[i]);
+    CHECK(ret[fwd_id] != nullptr);
+    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+        op, "FStatefulComputeEx", vctx[i]);
+    // FStatefulComputeEx is dispatched only when dispatch_mode is 
DispatchMode::kFComputeEx
+    if (fcompute_ex != nullptr && dispatch_modes[i] == 
DispatchMode::kFComputeEx) {
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(
+          ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+    } else {
+      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+          op, "FStatefulCompute", vctx[i]);
+      CHECK(fcompute != nullptr)
+          << "One of FStatefulCompute and FStatefulComputeEx must be 
registered "
+          << "for stateful operator " << op->name;
+      ret[i] = std::make_shared<StatefulComputeExecutor>(
+          ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
     }
+  } else {
+    FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+    FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", 
vctx[i]);
+    if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) 
{
+      ret[i] = std::make_shared<FComputeExExecutor>(
+          inode.source->attrs, fcomp_ex, exec_type);
+    } else if (fcompute != nullptr) {
+      ret[i] = std::make_shared<FComputeExecutor>(
+          inode.source->attrs, fcompute, exec_type, mutate_index);
+    } else {
+      LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+    }
+  }
+}
+
+
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
+  const auto& idx = g.indexed_graph();
+  OpExecVector ret(idx.num_nodes());
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    CreateOpExecs(g, &ret, i);
   }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
diff --git a/src/executor/attach_op_resource_pass.cc 
b/src/executor/attach_op_resource_pass.cc
index 6818662..56122cd 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,12 +30,15 @@
 namespace mxnet {
 namespace exec {
 
-Graph AttachOpResources(Graph g) {
+void AttachOpResources(
+    const Graph& g,
+    const OpExecVector& op_execs,
+    size_t start_nid,
+    size_t end_nid) {
   static auto& fresource =
       nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
   static auto& fresource_ex =
       nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
-  auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
   const auto& vctx = g.GetAttr<ContextVector>("context");
   const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
   const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -43,7 +46,7 @@ Graph AttachOpResources(Graph g) {
   // Use global resource pool for each executor for now.
   std::map<Context, Resource> cached_temp;
   // Resource allocation
-  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+  for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     const Context &ctx = vctx[nid];
@@ -84,7 +87,12 @@ Graph AttachOpResources(Graph g) {
       requested.push_back(ResourceManager::Get()->Request(ctx, 
ResourceRequest::kTempSpace));
     }
   }
-  return g;
 }
+
+void AttachOpResources(const Graph& g) {
+  const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
+  AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
+}
+
 }  // namespace exec
 }  // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 99b1b16..26a2491 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,6 +82,10 @@ class OpExecutor {
   virtual engine::VarHandle var() const {
     return nullptr;
   }
+  /*! \return return operator state */
+  virtual OpStatePtr state() const {
+    return OpStatePtr();
+  }
 };
 
 /*!
@@ -103,6 +107,14 @@ using ContextVector = std::vector<Context>;
 using DevMaskVector = std::vector<int>;
 
 /*!
+ * \brief create OpExecutor for a node in graph
+ *
+ * \param g input graph
+ * \param p_ret OpExecVector for input and output
+ * \param i the id of the node
+ */
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
+/*!
  * \brief Attach OpExecutor to the graph attributes.
  *
  * \param g input graph
@@ -115,12 +127,20 @@ Graph AttachOpExecs(Graph g);
  * \brief Attach Resource to the OpExecVector of the graph.
  *
  * \param g input graph need to contain op_exec attribute.
+ */
+void AttachOpResources(const Graph& g);
+/*!
+ * \brief Attach Resource to the OpExecVector
  *
- * \return graph with new attribute "op_exec" of type OpExecVector
- *  The fields on the OpExecVector are not yet been setup.
+ * \param g input graph
+ * \param op_execs OpExecutor vector
+ * \param start_nid starting node id
+ * \param end_nid end node id
  */
-Graph AttachOpResources(Graph g);
-
+void AttachOpResources(const Graph& g,
+                       const OpExecVector& op_execs,
+                       size_t start_nid,
+                       size_t end_nid);
 /*!
  * \brief Discover chance of inplace addto operators.
  *  i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e28867d..831b5f9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
 
   g = AttachOpExecs(g);
-  g = AttachOpResources(g);
+  AttachOpResources(g);
   graph_ = std::move(g);
 
   if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 140b5a5..c0e5e83 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,16 +19,78 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
+#include "./cached_op.h"
+#include "../executor/exec_pass.h"
+#include "../profiler/profiler.h"
+
 
 namespace mxnet {
 
 DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
-Imperative::CachedOp::CachedOp(
+struct CachedOp::GraphInfo {
+  nnvm::Graph fwd_graph;
+  nnvm::Graph full_graph;
+  std::vector<OpReqType> bwd_output_reqs;
+  std::vector<uint32_t> bwd_input_eid;
+};
+
+struct CachedOp::DynamicRuntime {
+  GraphInfo info;
+  std::vector<NDArray> buff;
+  std::vector<OpStatePtr> op_states;
+};
+
+struct CachedOp::CachedOpState {
+  CachedOpState(const Context& context_,
+                const nnvm::Graph& fwd_graph_,
+                const nnvm::Graph& full_graph_) {
+    context = context_;
+    info.fwd_graph = fwd_graph_;
+    info.full_graph = full_graph_;
+
+    size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
+    size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
+    info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), 
context));
+    info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(max_nodes, context));
+
+    buff.resize(max_entries);
+    arrays.resize(max_entries);
+    array_reqs.resize(max_entries);
+    dynamic_entries.resize(max_entries, false);
+    op_states.resize(max_nodes);
+    execs.resize(max_nodes);
+    opr_segs.resize(max_nodes);
+  }
+
+  std::mutex mutex;
+  Context context;
+  GraphInfo info;
+
+  bool recording = false;
+  bool fwd_alloc = false;
+  bool bwd_alloc = false;
+  bool fwd_exec_init = false;
+  bool bwd_exec_init = false;
+
+  std::vector<NDArray> buff;
+  std::vector<NDArray*> arrays;
+  std::vector<OpReqType> array_reqs;
+
+  std::vector<OpStatePtr> op_states;
+  std::vector<std::shared_ptr<exec::OpExecutor> > execs;
+  std::vector<imperative::EngineOprSeg> opr_segs;
+
+  std::vector<bool> dynamic_entries;
+  std::multimap<size_t, NDArray> fwd_reuse_pool;
+  std::multimap<size_t, NDArray> bwd_reuse_pool;
+};
+
+CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& flags,
-    const std::vector<std::string> arg_names,
-    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
+    const std::vector<std::pair<std::string, std::string> >& flags) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), 
Op::Get("_zeros")};
@@ -36,6 +98,10 @@ Imperative::CachedOp::CachedOp(
 
   config_.Init(flags);
 
+  if (config_.static_shape) {
+    CHECK(config_.static_alloc) << "static_alloc must be True when 
static_shape is True";
+  }
+
   // construct forward graph
   {
     NodeEntryMap<int> dedup_out;
@@ -68,34 +134,22 @@ Imperative::CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= 
config_.inline_limit;
+    inlining_ = !config_.static_alloc &&
+        (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
   }
 
   // Set params
   {
     const auto& idx = fwd_graph_.indexed_graph();
-    std::unordered_map<std::string, size_t> arg_name_to_id;
-    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
-      auto iter = params.find(name);
-      if (iter == params.end()) {
-        arg_name_to_id[name] = i;
-        continue;
-      }
-      fwd_params_idx_.push_back(i);
-      for (const auto& param : iter->second) {
-        params_[param.ctx()].emplace_back(param);
+    if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
+      CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
+               idx.input_nodes().size());
+    } else {
+      std::vector<uint32_t> tmp;
+      for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+        tmp.push_back(i);
       }
-    }
-
-    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
-        << "CachedOp expects " << arg_name_to_id.size()
-        << " inputs, given " << arg_names.size();
-
-    for (const auto& name : arg_names) {
-      auto iter = arg_name_to_id.find(name);
-      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
-      fwd_args_idx_.push_back(iter->second);
+      config_.data_indices.assign(tmp.begin(), tmp.end());
     }
   }
 
@@ -107,9 +161,14 @@ Imperative::CachedOp::CachedOp(
     }
 
     std::vector<NodeEntry> xs;
-    std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
-    xs.reserve(args.size());
-    for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
+    const auto& idx = fwd_graph_.indexed_graph();
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      auto nid = idx.input_nodes()[i];
+      if (idx.mutable_input_nodes().count(nid)) continue;
+      fwd_input_to_grad_output_[i] = xs.size();
+      xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
+    }
+
     CHECK_GT(xs.size(), 0)
         << "There are no inputs in computation graph that require gradients.";
 
@@ -125,7 +184,7 @@ Imperative::CachedOp::CachedOp(
     size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
     full_graph_.outputs = fwd_graph_.outputs;
-    curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
+    bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(), 
kWriteTo);
     for (const auto& i : grad_graph_.outputs) 
full_graph_.outputs.emplace_back(i);
     const auto& idx = full_graph_.indexed_graph();
 
@@ -169,7 +228,10 @@ Imperative::CachedOp::CachedOp(
   }
 }
 
-std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
+CachedOp::~CachedOp() {
+}
+
+std::vector<nnvm::NodeEntry> CachedOp::Gradient(
     const nnvm::NodePtr& node,
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
@@ -206,13 +268,15 @@ std::vector<nnvm::NodeEntry> 
Imperative::CachedOp::Gradient(
   return ret;
 }
 
-nnvm::Graph Imperative::CachedOp::GetForwardGraph(
-    const bool recording, const std::vector<NDArray*>& inputs) {
+
+bool CachedOp::SetForwardGraph(
+    GraphInfo* info,
+    const bool recording,
+    const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
-  std::lock_guard<std::mutex> lock(mutex_);
   CHECK_EQ(inputs.size(), num_inputs());
-  nnvm::Graph& g = fwd_graph_;
+  nnvm::Graph& g = info->fwd_graph;
 
   ShapeVector shape_inputs;
   DTypeVector dtype_inputs;
@@ -237,18 +301,22 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
     g.attrs.erase("forward_mem_plan");
     g.attrs.erase("full_mem_plan");
   } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
-    return g;
+    return true;
   }
 
   const auto& idx = g.indexed_graph();
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
-  for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = 
exec::kExternalStorageID;
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
   CHECK_EQ(stypes.size(), storage.size());
   for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage)
-      storage[i] = exec::kDynamicStorageID;
+    if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+  }
+  for (const auto i : idx.input_nodes()) {
+    storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
+  }
+  for (size_t i = 0; i < idx.outputs().size(); ++i) {
+    storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
   }
 
   auto mem_plan = PlanMemory(
@@ -257,51 +325,50 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
   g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
       std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return g;
+  return false;
 }
 
-nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
-    const OpStatePtr& op_state,
+bool CachedOp::SetBackwardGraph(
+    GraphInfo* info,
     const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& inputs) {
+    const std::vector<NDArray*>& inputs,
+    bool detect_inplace_addto) {
   using namespace nnvm;
   using namespace imperative;
   std::lock_guard<std::mutex> lock(mutex_);
-  nnvm::Graph& g = full_graph_;
-  auto& state = op_state.get_state<CachedOpState>();
-  bool req_match = true;
-  for (size_t i = 0; i < reqs.size(); ++i) {
-    if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
-      curr_grad_req_[i] = reqs[i] != kNullOp;
-      req_match = false;
-    }
-  }
-  if (!req_match) {
+  Context default_ctx = inputs[0]->ctx();
+  nnvm::Graph& g = info->full_graph;
+
+  if (info->bwd_output_reqs != reqs) {
+    info->bwd_output_reqs = reqs;
+    info->bwd_input_eid.clear();
     g = nnvm::Graph();
     g.outputs = fwd_graph_.outputs;
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
+      if (info->bwd_output_reqs[i] == kNullOp) continue;
+      g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
-    bwd_input_eid_.clear();
+    g.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
   }
 
   const auto& idx = g.indexed_graph();
 
-  if (bwd_input_eid_.size() != inputs.size()) {
-    bwd_input_eid_.clear();
+  if (info->bwd_input_eid.size() != inputs.size()) {
+    info->bwd_input_eid.clear();
     for (const auto& i : bwd_ograd_dep_) {
       auto eid = idx.entry_id(ograd_entries_[i]);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
     for (const auto& i : bwd_in_dep_) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
     for (const auto& i : bwd_out_dep_) {
       auto eid = idx.entry_id(idx.outputs()[i]);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
-    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+    CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
   }
 
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -312,25 +379,22 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
       for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
     }
-    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
+    for (size_t i = 0; i < inputs.size(); ++i) 
++ref_count[info->bwd_input_eid[i]];
     for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
     g.attrs["backward_ref_count"] = 
std::make_shared<dmlc::any>(std::move(ref_count));
   }
 
-  ShapeVector shapes(idx.num_node_entries(), TShape());
-  DTypeVector dtypes(idx.num_node_entries(), -1);
-  StorageTypeVector stypes(idx.num_node_entries(), -1);
-
-  for (size_t i = 0; i < num_forward_entries; ++i) {
-    shapes[i] = state.buff[i].shape();
-    dtypes[i] = state.buff[i].dtype();
-    stypes[i] = state.buff[i].storage_type();
-  }
+  auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
+  shapes.resize(idx.num_node_entries(), TShape());
+  auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
+  dtypes.resize(idx.num_node_entries(), -1);
+  auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
+  stypes.resize(idx.num_node_entries(), -1);
 
   for (size_t i = 0; i < inputs.size(); ++i) {
-    shapes[bwd_input_eid_[i]] = inputs[i]->shape();
-    dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
-    stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
+    shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
+    dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
+    stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
   }
 
   std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -342,79 +406,353 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
                               node_range, entry_range);
   match &= CheckAndInferType(&g, std::move(dtypes), false,
                              node_range, entry_range);
-  exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
+  exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
   match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
                                     false, node_range, entry_range);
 
   if (!match) {
     g.attrs.erase("backward_mem_plan");
   } else if (g.attrs.count("backward_mem_plan")) {
-    return g;
+    return true;
   }
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+  const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < bwd_stypes.size(); i++) {
+    if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+  }
   for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = 
exec::kExternalStorageID;
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = 
exec::kExternalStorageID;
   for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = 
exec::kExternalStorageID;
-  for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage)
-      storage[i] = exec::kDynamicStorageID;
-  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> 
>("backward_ref_count"),
-      {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, 
idx.num_node_entries()});
+      {num_forward_nodes, idx.num_nodes()},
+      {num_forward_entries, idx.num_node_entries()},
+      detect_inplace_addto);
   g.attrs["backward_mem_plan"] = 
std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return g;
+  return false;
 }
 
-void Imperative::CachedOp::Forward(
-    const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& args,
-    const std::vector<NDArray*>& outputs) {
+OpStatePtr CachedOp::GetCachedOpState(
+    const Context& ctx) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  for (const auto& i : cached_op_states_[ctx]) {
+    // only create one state per device when not using static memory
+    if (!config_.static_alloc || i.unique()) {
+      return i;
+    }
+  }
+  auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_, 
full_graph_);
+
+  cached_op_states_[ctx].push_back(state_ptr);
+  return state_ptr;
+}
+
+void CachedOp::StaticAllocMemory(
+    const OpStatePtr& state_ptr,
+    bool recording,
+    bool keep_fwd) {
   using namespace nnvm;
   using namespace imperative;
-  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  CHECK_EQ(args.size(), fwd_args_idx_.size())
-      << "CachedOp requires " << fwd_args_idx_.size()
-      << " inputs but got " << args.size();
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& default_ctx = state.context;
+  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  const auto& vstorage_inplace = g.GetAttr<std::vector<int> 
>("storage_inplace_index");
+  const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
+      keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : 
"forward_mem_plan"));
+  std::vector<int> addto_entry;
+  if (g.attrs.count("addto_entry")) {
+    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+  }
+  size_t start_eid =
+      keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
+  size_t end_eid = idx.num_node_entries();
+
+  if (!keep_fwd) state.fwd_alloc = false;
+  state.bwd_alloc = false;
+  for (size_t i = start_eid; i < state.buff.size(); ++i) {
+    state.buff[i] = NDArray();
+    state.arrays[i] = &state.buff[i];
+    state.array_reqs[i] = kNullOp;
+    state.dynamic_entries[i] = false;
+  }
+
+  for (auto i : idx.input_nodes()) {
+    auto eid = idx.entry_id(i, 0);
+    if (eid >= start_eid) state.dynamic_entries[eid] = true;
+  }
+  for (auto i : idx.outputs()) {
+    auto eid = idx.entry_id(i);
+    if (eid >= start_eid) state.dynamic_entries[eid] = true;
+  }
 
-  Context default_ctx = args[0]->ctx();
+  for (size_t i = start_eid; i < end_eid; ++i) {
+    if (addto_entry.size() && addto_entry[i]) {
+      state.array_reqs[i] = kAddTo;
+    } else if (vstorage_inplace[i] >= 0) {
+      state.array_reqs[i] = kWriteInplace;
+    } else if (vstorage_inplace[i] == -2) {
+      // -2 indicate that the entry is never referenced.
+      state.array_reqs[i] = kNullOp;
+    } else {
+      state.array_reqs[i] = kWriteTo;
+    }
+  }
 
+  auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
+  reuse_pool = imperative::AllocateMemory(
+      g, idx, default_ctx, start_eid, end_eid, mem_plan,
+      state.arrays, &state.array_reqs, std::move(reuse_pool));
 
-  std::vector<NDArray*> inputs(num_inputs());
-  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
-    inputs[fwd_args_idx_[i]] = args[i];
+  state.recording = recording;
+  if (keep_fwd) {
+    state.bwd_alloc = true;
+  } else {
+    state.fwd_alloc = true;
   }
-  if (fwd_params_idx_.size()) {
-    CHECK(params_.find(default_ctx) != params_.end())
-        << "CachedOp is not initialized on context " << default_ctx;
+}
+
+void CachedOp::StaticInitExec(
+    const OpStatePtr& state_ptr,
+    bool recording,
+    bool keep_fwd) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& default_ctx = state.context;
+  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  std::vector<int> skip_plus_node;
+  if (g.attrs.count("skip_plus_node")) {
+    skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
+  }
+  size_t start_nid =
+      keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
+  size_t end_nid = idx.num_nodes();
+
+  if (!keep_fwd) state.fwd_exec_init = false;
+  state.bwd_exec_init = false;
 
-    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
-      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
+  for (size_t i = start_nid; i < state.execs.size(); ++i) {
+    state.execs[i].reset();
+    state.opr_segs[i] = EngineOprSeg();
+  }
+
+  if (!config_.static_shape) {
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      state.opr_segs[i].next_nid = i + 1;
+      state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
+    }
+  } else {
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      exec::CreateOpExecs(g, &state.execs, i);
     }
+    exec::AttachOpResources(g, state.execs, start_nid, end_nid);
+
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      bool skip = idx[i].source->is_variable();
+      for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
+        skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
+      }
+      for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
+        skip = state.dynamic_entries[idx.entry_id(i, j)];
+      }
+      if (skip) continue;
+      SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
+    }
+
+    size_t bulk_size = idx.num_nodes();
+    std::unordered_set<uint32_t> excludes;
+    if (recording || keep_fwd) {
+      bulk_size = keep_fwd ? config_.backward_bulk_size : 
config_.forward_bulk_size;
+      for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
+      for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 
0));
+    }
+
+    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, 
excludes,
+                      state.execs, skip_plus_node, &state.opr_segs);
   }
 
-  // Initialize
+  if (keep_fwd) {
+    state.bwd_exec_init = true;
+  } else {
+    state.fwd_exec_init = true;
+  }
+}
+
+void CachedOp::StaticRunOps(
+    const Context& default_ctx,
+    const nnvm::Graph& g,
+    const OpStatePtr& state_ptr,
+    size_t start_nid,
+    size_t end_nid) {
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+
+  bool profiling = profiler::Profiler::Get()->GetState() == 
profiler::Profiler::kRunning;
+  bool is_training = Imperative::Get()->is_training();
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& idx = g.indexed_graph();
+  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+  const auto& op_execs = state.execs;
+
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  nnvm::ShapeVector arg_shapes;
+  nnvm::DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+
+  for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
+    if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
+  }
+
+  for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
+    const auto& opr_seg = state.opr_segs[i];
+    if (opr_seg.skip) continue;
+    if (opr_seg.opr != nullptr) {
+      Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
+    } else {
+      const nnvm::IndexedGraph::Node& node = idx[i];
+      if (node.source->is_variable()) continue;
+      auto num_outputs = node.source->num_outputs();
+      ndinputs.clear();
+      ndinputs.reserve(node.inputs.size());
+      for (const auto& j : node.inputs) {
+        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+        CHECK(!ndinputs.back()->is_none());
+      }
+      ndoutputs.clear();
+      ndoutputs.reserve(num_outputs);
+      req.clear();
+      req.reserve(num_outputs);
+      for (size_t j = 0; j < num_outputs; ++j) {
+        size_t eid = idx.entry_id(i, j);
+        ndoutputs.emplace_back(state.arrays[eid]);
+        req.push_back(state.array_reqs[eid]);
+        CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
+      }
+      const DispatchMode dispatch_mode = dispatch_modes[i];
+      if (createop.count(node.source->op())) {
+        arg_shapes.clear();
+        arg_dtypes.clear();
+        arg_shapes.reserve(ndinputs.size());
+        arg_dtypes.reserve(ndinputs.size());
+        for (size_t i = 0; i < ndinputs.size(); ++i) {
+          arg_shapes.emplace_back(ndinputs[i]->shape());
+          arg_dtypes.emplace_back(ndinputs[i]->dtype());
+        }
+        state.op_states[i] = createop[node.source->op()](
+            node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+            dispatch_mode, state.op_states[i]);
+      } else if (is_layer_backward.get(node.source->op(), false)) {
+        nnvm::Node* fwd_node = node.source->control_deps[0].get();
+        auto fwd_node_id = idx.node_id(fwd_node);
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs,
+            req, dispatch_mode, state.op_states[fwd_node_id]);
+      } else {
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+            dispatch_mode);
+      }
+    }
+  }
+}
+
+OpStatePtr CachedOp::StaticForward(
+    const Context& default_ctx,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+
   bool recording = Imperative::Get()->is_recording();
-  nnvm::Graph g = GetForwardGraph(recording, inputs);
+  auto state_ptr = GetCachedOpState(default_ctx);
+  auto& state = state_ptr.get_state<CachedOpState>();
+  std::lock_guard<std::mutex> lock(state.mutex);
+
+  bool match = SetForwardGraph(&state.info, recording, inputs);
+  match = match && state.recording == recording;
+
+  nnvm::Graph& g = state.info.fwd_graph;
   const auto& idx = g.indexed_graph();
-  size_t num_inputs = idx.input_nodes().size();
+  if (!state.fwd_alloc || !match)  {
+    StaticAllocMemory(state_ptr, recording, false);
+  }
 
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << 
default_ctx
-        << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is 
on "
-        << inputs[i]->ctx();
+  if (config_.static_shape) {
+    for (auto i : config_.param_indices) {
+      auto nid = idx.input_nodes()[i];
+      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+        match = false;
+        auto ptr = &state.buff[idx.entry_id(nid, 0)];
+        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
+        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
+        state.dynamic_entries[idx.entry_id(nid, 0)] = false;
+      }
+    }
+    for (auto i : config_.data_indices) {
+      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+      state.arrays[eid] = inputs[i];
+    }
+  } else {
+    for (size_t i = 0; i < num_inputs(); ++i) {
+      auto nid = idx.input_nodes()[i];
+      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+    }
   }
 
-  auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
-  auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
-  auto& buff = cached_op_state.buff;
-  auto& states = cached_op_state.states;
+  if (!state.fwd_exec_init || !match) {
+    StaticInitExec(state_ptr, recording, false);
+  }
+
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shapes = g.GetAttr<ShapeVector>("shape");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    state.arrays[eid] = outputs[i];
+    if (!outputs[i]->is_none()) continue;
+    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+                          shapes[eid], default_ctx, true, dtypes[eid]);
+  }
+
+  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+
+  return recording ? state_ptr : OpStatePtr();
+}
+
+
+OpStatePtr CachedOp::DynamicForward(
+    const Context& default_ctx,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  // Initialize
+  bool recording = Imperative::Get()->is_recording();
+  auto op_state = OpStatePtr::Create<DynamicRuntime>();
+  auto& runtime = op_state.get_state<DynamicRuntime>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
+    std::lock_guard<std::mutex> lock(state.mutex);
+    SetForwardGraph(&state.info, recording, inputs);
+    runtime.info.fwd_graph = state.info.fwd_graph;
+  }
+  nnvm::Graph& g = runtime.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  size_t num_inputs = idx.input_nodes().size();
+  auto& buff = runtime.buff;
+  auto& states = runtime.op_states;
 
   // Allocate entries
   states.resize(idx.num_nodes());
@@ -446,57 +784,98 @@ void Imperative::CachedOp::Forward(
   AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shapes = g.GetAttr<ShapeVector>("shape");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    arrays[eid] = outputs[i];
+    if (!outputs[i]->is_none()) continue;
+    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+                          shapes[eid], default_ctx, true, dtypes[eid]);
+  }
+
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
-  Imperative::Get()->RunGraph(
-      false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-      std::move(ref_count), &states, dispatch_modes);
+  RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+           std::move(ref_count), &states, dispatch_modes);
 
-  Engine::Get()->set_bulk_size(prev_bulk_size);
   Imperative::Get()->set_is_recording(recording);
 
-  for (size_t i = 0; i < idx.num_node_entries(); ++i) {
-    if (arrays[i] == &buff[i]) continue;
-    buff[i].shape_ = arrays[i]->shape_;
-    buff[i].dtype_ = arrays[i]->dtype_;
-    buff[i].storage_type_ = arrays[i]->storage_type_;
+  return op_state;
+}
+
+void CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
+
+  CHECK_EQ(inputs.size(), num_inputs());
+
+  Context default_ctx = inputs[0]->ctx();
+
+  const auto& idx = fwd_graph_.indexed_graph();
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->ctx(), default_ctx)
+        << "CachedOp requires all inputs to live on the same context. But "
+        << idx[idx.input_nodes()[0]].source->attrs.name
+        << " is on " << default_ctx << " while "
+        << idx[idx.input_nodes()[i]].source->attrs.name
+        << " is on " << inputs[i]->ctx();
   }
 
-  if (recording && !inlining_) {
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
+
+  OpStatePtr op_state;
+  if (config_.static_alloc) {
+    op_state = StaticForward(default_ctx, inputs, outputs);
+  } else {
+    op_state = DynamicForward(default_ctx, inputs, outputs);
+  }
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+
+  if (Imperative::Get()->is_recording() && !inlining_) {
     nnvm::NodeAttrs attrs;
     attrs.op = cached_op;
     attrs.name = "_cachedop";
     attrs.parsed = op_ptr;
     Imperative::Get()->RecordOp(
-        std::move(attrs), inputs, outputs, op_state_ptr,
+        std::move(attrs), inputs, outputs, op_state,
         &save_inputs(), &save_outputs());
   }
 }
 
 
-void Imperative::CachedOp::Backward(
+void CachedOp::DynamicBackward(
     const bool retain_graph,
-    const OpStatePtr& state,
+    const OpStatePtr& op_state,
     const std::vector<NDArray*>& inputs,
     const std::vector<OpReqType>& reqs,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
-  CHECK(!Imperative::Get()->is_recording())
-      << "CachedOp does not support higher order gradients. "
-      << "If you want to do backward with create_graph=True please "
-      << "do not use hybridize.";
 
   // Initialize
-  nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
+  Context default_ctx = outputs[0]->ctx();
+  auto& runtime = op_state.get_state<DynamicRuntime>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
+    std::lock_guard<std::mutex> lock(state.mutex);
+    state.info.fwd_graph = runtime.info.fwd_graph;
+    SetBackwardGraph(&state.info, reqs, inputs);
+    runtime.info.full_graph = state.info.full_graph;
+    runtime.info.bwd_input_eid = state.info.bwd_input_eid;
+  }
+  nnvm::Graph& g = runtime.info.full_graph;
   const auto& idx = g.indexed_graph();
-
-  auto& cached_op_state = state.get_state<CachedOpState>();
-  auto& buff = cached_op_state.buff;
-  auto& states = cached_op_state.states;
+  auto& buff = runtime.buff;
+  auto& states = runtime.op_states;
 
   size_t num_forward_outputs = fwd_graph_.outputs.size();
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -506,7 +885,7 @@ void Imperative::CachedOp::Backward(
   arrays.reserve(buff.size());
   for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
   for (size_t i = 0; i < inputs.size(); ++i) {
-    arrays[bwd_input_eid_[i]] = inputs[i];
+    arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
@@ -530,20 +909,14 @@ void Imperative::CachedOp::Backward(
     if (ref_count[i] == 0) array_reqs[i] = kNullOp;
   }
 
-  Context default_ctx = outputs[0]->ctx();
   const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
   AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
-  Imperative::Get()->RunGraph(
-      retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-      std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
+  RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -553,6 +926,99 @@ void Imperative::CachedOp::Backward(
   }
 }
 
+void CachedOp::StaticBackward(
+    const bool retain_graph,
+    const OpStatePtr& state_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<OpReqType>& reqs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  Context default_ctx = outputs[0]->ctx();
+
+  auto& state = state_ptr.get_state<CachedOpState>();
+  std::lock_guard<std::mutex> lock(state.mutex);
+
+  bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
+
+  nnvm::Graph& g = state.info.full_graph;
+  const auto& idx = g.indexed_graph();
+  auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
+
+  if (!state.bwd_alloc || !match) {
+    StaticAllocMemory(state_ptr, true, true);
+  }
+
+  if (config_.static_shape) {
+    for (auto i : config_.param_indices) {
+      const auto iter = fwd_input_to_grad_output_.find(i);
+      if (iter == fwd_input_to_grad_output_.end()) continue;
+      auto entry = grad_graph_.outputs[iter->second];
+      if (!idx.exist(entry.node.get())) continue;
+      auto eid = idx.entry_id(entry);
+      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+          !(state.array_reqs[eid] == reqs[iter->second])) {
+        match = false;
+        state.array_reqs[eid] = reqs[iter->second];
+        *state.arrays[eid] = *outputs[iter->second];
+        state.dynamic_entries[eid] = false;
+      }
+    }
+    for (auto i : config_.data_indices) {
+      const auto iter = fwd_input_to_grad_output_.find(i);
+      if (iter == fwd_input_to_grad_output_.end()) continue;
+      auto entry = grad_graph_.outputs[iter->second];
+      if (!idx.exist(entry.node.get())) continue;
+      auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[iter->second];
+      state.arrays[eid] = outputs[iter->second];
+    }
+  } else {
+    for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
+      auto entry = grad_graph_.outputs[i];
+      if (!idx.exist(entry.node.get())) continue;
+      auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[i];
+      state.arrays[eid] = outputs[i];
+    }
+  }
+
+  if (!state.bwd_exec_init || !match) {
+    StaticInitExec(state_ptr, true, true);
+  }
+
+  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+    auto eid = state.info.bwd_input_eid[i];
+    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
+  }
+
+  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+}
+
+void CachedOp::Backward(
+    const bool retain_graph,
+    const OpStatePtr& state,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<OpReqType>& reqs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace imperative;
+  CHECK(!Imperative::Get()->is_recording())
+      << "CachedOp does not support higher order gradients. "
+      << "If you want to do backward with create_graph=True please "
+      << "do not use hybridize.";
+
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+
+  if (config_.static_alloc) {
+    StaticBackward(retain_graph, state, inputs, reqs, outputs);
+  } else {
+    DynamicBackward(retain_graph, state, inputs, reqs, outputs);
+  }
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+}
+
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
new file mode 100644
index 0000000..60a40c5
--- /dev/null
+++ b/src/imperative/cached_op.h
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
+#define MXNET_IMPERATIVE_CACHED_OP_H_
+
+#include <mxnet/imperative.h>
+#include <vector>
+#include <atomic>
+#include <utility>
+#include <string>
+#include <unordered_map>
+
+namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  bool static_alloc;
+  bool static_shape;
+  nnvm::Tuple<uint32_t> data_indices;
+  nnvm::Tuple<uint32_t> param_indices;
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+    DMLC_DECLARE_FIELD(static_alloc)
+    .set_default(false)
+    .describe("Statically allocate memory to improve speed. "
+              "Memory usage may increase.");
+    DMLC_DECLARE_FIELD(static_shape)
+    .set_default(false)
+    .describe("Optimize for invariant input shapes between iterations. "
+              "Must also set static_alloc to True. "
+              "Change of input shapes is still allowed but slower.");
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+    DMLC_DECLARE_FIELD(data_indices)
+    .set_default(nnvm::Tuple<uint32_t>())
+    .describe("Position of argument variables.");
+    DMLC_DECLARE_FIELD(param_indices)
+    .set_default(nnvm::Tuple<uint32_t>())
+    .describe("Position of parameters.");
+  }
+};
+
+class CachedOp {
+ public:
+  CachedOp(
+      const nnvm::Symbol& sym,
+      const std::vector<std::pair<std::string, std::string> >& flags);
+  ~CachedOp();
+  uint32_t num_inputs() {
+    return fwd_graph_.indexed_graph().input_nodes().size();
+  }
+  uint32_t num_outputs() {
+    return fwd_graph_.outputs.size();
+  }
+  uint32_t num_backward_inputs() {
+    return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+  }
+  std::vector<bool>& save_inputs() {
+    return save_inputs_;
+  }
+  std::vector<bool>& save_outputs() {
+    return save_outputs_;
+  }
+  const std::unordered_set<uint32_t>& mutable_input_nodes() {
+    return fwd_graph_.indexed_graph().mutable_input_nodes();
+  }
+  std::vector<nnvm::NodeEntry> Gradient(
+      const nnvm::NodePtr& node,
+      const std::vector<nnvm::NodeEntry>& ograds);
+  void Forward(
+      const std::shared_ptr<CachedOp>& op_ptr,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void Backward(
+      const bool retain_graph,
+      const OpStatePtr& state,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+
+ private:
+  struct GraphInfo;
+  struct DynamicRuntime;
+  struct CachedOpState;
+
+  OpStatePtr GetCachedOpState(const Context& ctx);
+  bool SetForwardGraph(
+      GraphInfo* info,
+      const bool recording,
+      const std::vector<NDArray*>& inputs);
+  bool SetBackwardGraph(
+      GraphInfo* info,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& inputs,
+      bool detect_inplace_addto = false);
+  OpStatePtr DynamicForward(
+      const Context& default_ctx,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void DynamicBackward(
+      const bool retain_graph,
+      const OpStatePtr& op_state,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+  void StaticAllocMemory(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticInitExec(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticRunOps(
+      const Context& default_ctx,
+      const nnvm::Graph& g,
+      const OpStatePtr& state_ptr,
+      size_t start_nid,
+      size_t end_nid);
+  OpStatePtr StaticForward(
+      const Context& default_ctx,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void StaticBackward(
+      const bool retain_graph,
+      const OpStatePtr& state_ptr,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+
+  CachedOpConfig config_;
+  nnvm::Graph fwd_graph_;
+  nnvm::Graph grad_graph_;
+  nnvm::Graph full_graph_;
+  bool inlining_;
+  std::vector<nnvm::NodeEntry> ograd_entries_;
+  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
+  std::vector<bool> save_inputs_, save_outputs_;
+  std::vector<OpReqType> bwd_output_reqs_;
+
+  std::mutex mutex_;
+  std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
+};
+
+using CachedOpPtr = std::shared_ptr<CachedOp>;
+
+}  // namespace mxnet
+#endif  // MXNET_IMPERATIVE_CACHED_OP_H_
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 7caf305..e165425 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,6 +19,7 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
+#include "./cached_op.h"
 
 namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
@@ -266,95 +267,6 @@ void Imperative::RecordOp(
   }
 }
 
-void Imperative::RunGraph(
-    const bool retain_graph,
-    const nnvm::IndexedGraph& idx,
-    const std::vector<NDArray*> arrays,
-    size_t node_start, size_t node_end,
-    std::vector<OpReqType>&& array_reqs,
-    std::vector<uint32_t>&& ref_count,
-    std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
-  using namespace nnvm;
-  using namespace imperative;
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
-  std::vector<OpStatePtr>& states = *p_states;
-  bool recording = is_recording();
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  ShapeVector arg_shapes;
-  DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
-
-  for (size_t i = node_start; i < node_end; ++i) {
-    const nnvm::IndexedGraph::Node& node = idx[i];
-    if (node.source->op() == nullptr) continue;
-    auto num_outputs = node.source->num_outputs();
-    ndinputs.clear();
-    ndinputs.reserve(node.inputs.size());
-    for (const auto& j : node.inputs) {
-      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
-      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
-    }
-    ndoutputs.clear();
-    ndoutputs.reserve(num_outputs);
-    req.clear();
-    req.reserve(num_outputs);
-    for (size_t j = 0; j < num_outputs; ++j) {
-      size_t eid = idx.entry_id(i, j);
-      ndoutputs.emplace_back(arrays[eid]);
-      req.push_back(array_reqs[eid]);
-      CHECK(!ndoutputs.back()->is_none());
-    }
-    const Context& ctx = ndoutputs[0]->ctx();
-    const DispatchMode dispatch_mode = dispatch_modes[i];
-    if (node.source->op() == bwd_cached_op) {
-      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
-    } else if (createop.count(node.source->op())) {
-      arg_shapes.clear();
-      arg_dtypes.clear();
-      arg_shapes.reserve(ndinputs.size());
-      arg_dtypes.reserve(ndinputs.size());
-      for (size_t i = 0; i < ndinputs.size(); ++i) {
-        arg_shapes.emplace_back(ndinputs[i]->shape());
-        arg_dtypes.emplace_back(ndinputs[i]->dtype());
-      }
-      states[i] = createop[node.source->op()](
-          node.source->attrs, ctx, arg_shapes, arg_dtypes);
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
-      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
-    } else if (is_layer_backward.get(node.source->op(), false)) {
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
-               req, dispatch_mode, states[fwd_node_id]);
-      if (recording) {
-        RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
-      }
-    } else {
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
-      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
-    }
-
-    for (const auto& j : node.inputs) {
-      size_t eid = idx.entry_id(j);
-      --ref_count[eid];
-      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
-    }
-    for (size_t j = 0; j < ndoutputs.size(); ++j) {
-      size_t eid = idx.entry_id(i, j);
-      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
-    }
-  }
-}
-
-
 std::vector<NDArray*> Imperative::Backward(
     const std::vector<NDArray*>& outputs,
     const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
new file mode 100644
index 0000000..464aefc
--- /dev/null
+++ b/src/imperative/imperative_utils.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./imperative_utils.h"
+#include "./cached_op.h"
+
+namespace mxnet {
+namespace imperative {
+void RunGraph(
+    const bool retain_graph,
+    const nnvm::IndexedGraph& idx,
+    const std::vector<NDArray*> arrays,
+    size_t node_start, size_t node_end,
+    std::vector<OpReqType>&& array_reqs,
+    std::vector<uint32_t>&& ref_count,
+    std::vector<OpStatePtr> *p_states,
+    const DispatchModeVector &dispatch_modes) {
+  using namespace nnvm;
+  using namespace imperative;
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+  const auto imp = Imperative::Get();
+
+  std::vector<OpStatePtr>& states = *p_states;
+  bool recording = imp->is_recording();
+
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  ShapeVector arg_shapes;
+  DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+
+  for (size_t i = node_start; i < node_end; ++i) {
+    const nnvm::IndexedGraph::Node& node = idx[i];
+    if (node.source->op() == nullptr) continue;
+    auto num_outputs = node.source->num_outputs();
+    ndinputs.clear();
+    ndinputs.reserve(node.inputs.size());
+    for (const auto& j : node.inputs) {
+      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
+    }
+    ndoutputs.clear();
+    ndoutputs.reserve(num_outputs);
+    req.clear();
+    req.reserve(num_outputs);
+    for (size_t j = 0; j < num_outputs; ++j) {
+      size_t eid = idx.entry_id(i, j);
+      ndoutputs.emplace_back(arrays[eid]);
+      req.push_back(array_reqs[eid]);
+      CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
+    }
+    const Context& ctx = ndoutputs[0]->ctx();
+    const DispatchMode dispatch_mode = dispatch_modes[i];
+    if (node.source->op() == bwd_cached_op) {
+      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
+    } else if (createop.count(node.source->op())) {
+      arg_shapes.clear();
+      arg_dtypes.clear();
+      arg_shapes.reserve(ndinputs.size());
+      arg_dtypes.reserve(ndinputs.size());
+      for (size_t i = 0; i < ndinputs.size(); ++i) {
+        arg_shapes.emplace_back(ndinputs[i]->shape());
+        arg_dtypes.emplace_back(ndinputs[i]->dtype());
+      }
+      states[i] = createop[node.source->op()](
+          node.source->attrs, ctx, arg_shapes, arg_dtypes);
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
+      }
+    } else if (is_layer_backward.get(node.source->op(), false)) {
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+               req, dispatch_mode, states[fwd_node_id]);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
+      }
+    } else {
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+      }
+    }
+
+    for (const auto& j : node.inputs) {
+      size_t eid = idx.entry_id(j);
+      --ref_count[eid];
+      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+    }
+    for (size_t j = 0; j < ndoutputs.size(); ++j) {
+      size_t eid = idx.entry_id(i, j);
+      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+    }
+  }
+}
+
+}  // namespace imperative
+}  // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 06b7e05..726531d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,6 +23,7 @@
 #include <utility>
 #include <algorithm>
 #include <vector>
+#include <map>
 #include <string>
 #include "../executor/graph_executor.h"
 #include "../executor/exec_pass.h"
@@ -38,11 +39,24 @@ namespace mxnet {
 namespace imperative {
 
 struct MemoryPlanInfo {
-  uint32_t sid;
+  int storage_id;
+  uint32_t root;
   size_t size;
   bool inplace;
 };
 
+struct EngineOprDeleter {
+  void operator()(engine::Opr* handle) {
+    Engine::Get()->DeleteOperator(handle);
+  }
+};
+
+struct EngineOprSeg {
+  bool skip;
+  size_t next_nid;
+  std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
+};
+
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -715,10 +729,12 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
 
 
 inline MemoryPlanVector PlanMemory(
-    nnvm::Graph* p_g, nnvm::StorageVector&& storage,
+    nnvm::Graph* p_g,
+    nnvm::StorageVector&& storage,
     const std::vector<uint32_t>& ref_count,
     const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
-    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
+    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
+    bool detect_inplace_addto = false) {
   using namespace nnvm;
   nnvm::Graph& g = *p_g;
   const auto& idx = g.indexed_graph();
@@ -728,31 +744,31 @@ inline MemoryPlanVector PlanMemory(
   g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
   g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
   g = nnvm::ApplyPass(g, "PlanMemory");
+  if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
 
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
-  auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
+  const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
+  const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
   uint32_t entry_start = entry_range.first;
   uint32_t entry_end =
       entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
   MemoryPlanVector mem_plan(idx.num_node_entries());
-  std::unordered_map<int, uint32_t> sid_to_loc;
+  std::unordered_map<int, uint32_t> sid_to_root;
 
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (stypes[i] != kDefaultStorage) continue;
     if (storage_ids[i] < 0) {
-      mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
-    } else if (!sid_to_loc.count(storage_ids[i])) {
+      mem_plan[i] = {storage_ids[i], i, 0, false};
+    } else if (!sid_to_root.count(storage_ids[i])) {
       CHECK_LT(storage_inplace[i], 0);
-      sid_to_loc[storage_ids[i]] = i;
-      mem_plan[i].sid = i;
-      mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
+      sid_to_root[storage_ids[i]] = i;
+      mem_plan[i] = {storage_ids[i], i,
+                     mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
+                     false};
     } else {
-      uint32_t loc = sid_to_loc[storage_ids[i]];
-      mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
-      mem_plan[loc].size = std::max(mem_plan[loc].size,
+      uint32_t root = sid_to_root[storage_ids[i]];
+      mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
+      mem_plan[root].size = std::max(mem_plan[root].size,
           mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
     }
   }
@@ -761,39 +777,213 @@ inline MemoryPlanVector PlanMemory(
 }
 
 
-inline void AllocateMemory(const nnvm::Graph& g,
-                    const nnvm::IndexedGraph& idx,
-                    const Context& default_ctx,
-                    const uint32_t entry_start, const uint32_t entry_end,
-                    const MemoryPlanVector& mem_plan,
-                    const std::vector<NDArray*>& arrays,
-                    std::vector<OpReqType> *array_reqs) {
+inline std::multimap<size_t, NDArray> AllocateMemory(
+    const nnvm::Graph& g,
+    const nnvm::IndexedGraph& idx,
+    const Context& default_ctx,
+    const uint32_t entry_start, const uint32_t entry_end,
+    const MemoryPlanVector& mem_plan,
+    const std::vector<NDArray*>& arrays,
+    std::vector<OpReqType> *array_reqs,
+    std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
   using namespace nnvm;
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
 
+  std::multimap<size_t, NDArray> new_pool;
+
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (!arrays[i]->is_none()) continue;
-    if (stypes[i] == kDefaultStorage) {
-      if (mem_plan[i].sid == i) {
-        CHECK_GT(mem_plan[i].size, 0);
+    if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
+    CHECK(arrays[i]->is_none());
+    if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
+      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+                           shapes[i], default_ctx, true, dtypes[i]);
+      continue;
+    }
+    CHECK_EQ(stypes[i], kDefaultStorage);
+    if (mem_plan[i].root == i) {
+      CHECK_GT(mem_plan[i].size, 0);
+      auto iter = pool.lower_bound(mem_plan[i].size);
+      if (iter != pool.end()) {
+        *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
+        new_pool.insert(*iter);
+        pool.erase(iter);
+      } else {
         NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
                      default_ctx, true, mshadow::kUint8);
         *arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
+        new_pool.insert({mem_plan[i].size, buff});
+      }
+    } else {
+      CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
+      *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
+      if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+        array_reqs->at(i) = kWriteInplace;
+      }
+    }
+  }
+
+  return new_pool;
+}
+
+inline void SetupOpExec(
+    const nnvm::Graph& g,
+    size_t nid,
+    const std::shared_ptr<exec::OpExecutor>& exec,
+    const std::vector<NDArray*> arrays,
+    const std::vector<OpReqType> array_reqs) {
+  const auto& idx = g.indexed_graph();
+  const auto& inode = idx[nid];
+  CHECK_EQ(exec->in_array.size(), 0U);
+  CHECK_EQ(exec->out_array.size(), 0U);
+  for (const auto& e : inode.inputs) {
+    CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
+    exec->in_array.push_back(*arrays[idx.entry_id(e)]);
+  }
+  for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+    uint32_t eid = idx.entry_id(nid, index);
+    CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
+    exec->out_array.push_back(*arrays[eid]);
+    exec->req.push_back(array_reqs[eid]);
+  }
+
+  exec->Setup();
+}
+
+inline Engine::OprHandle CreateEngineOp(
+    const Context& default_ctx,
+    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
+  CHECK_GT(execs.size(), 0);
+  std::vector<Engine::VarHandle> use_vars, mutate_vars;
+
+  for (const auto& exec : execs) {
+    CHECK_GT(exec->out_array.size(), 0);
+    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);
+
+    // the variables
+    for (const auto& nd : exec->in_array) {
+      use_vars.push_back(nd.var());
+    }
+    for (auto& r : exec->op_ctx.requested) {
+      mutate_vars.push_back(r.var);
+    }
+    for (auto& nd : exec->out_array) {
+      mutate_vars.push_back(nd.var());
+    }
+    if (exec->var() != nullptr) {
+      mutate_vars.push_back(exec->var());
+    }
+  }
+
+  // dedup vars
+  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
+  bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
+  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
+
+  auto exec_fun = [execs, is_async, is_gpu] (
+      RunContext ctx, Engine::CallbackOnComplete on_complete) {
+    if (is_async) {
+      execs[0]->op_ctx.async_on_complete = on_complete;
+    }
+    for (const auto& exec : execs) exec->Run(ctx, is_gpu);
+    // call on complete only if it is async op
+    if (!is_async) {
+      if (is_gpu) {
+      #if MXNET_USE_CUDA
+        // Wait GPU kernel to finish.
+        ctx.get_stream<gpu>()->Wait();
+      #else
+        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+      #endif
+      }
+      on_complete();
+    }
+  };
+
+  return Engine::Get()->NewOperator(
+      exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
+}
+
+inline void CreateEngineOpSeg(
+    const nnvm::IndexedGraph& idx,
+    const Context default_ctx,
+    const size_t start_nid,
+    const size_t end_nid,
+    const size_t bulk_size,
+    const std::unordered_set<uint32_t>& excludes,
+    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
+    const std::vector<int> skip_plus_node,
+    std::vector<EngineOprSeg> *opr_segs) {
+  size_t seg_start = start_nid;
+  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
+  for (size_t nid = start_nid; nid < end_nid; ++nid) {
+    const auto& node = idx[nid];
+    if (node.source->is_variable()) continue;
+    if (skip_plus_node.size() && skip_plus_node[nid]) continue;
+    auto& exec = execs[nid];
+    bool is_async = exec->exec_type() != ExecType::kSync;
+    bool valid = exec->out_array.size() > 0;
+
+    // Stop at async nodes and invalid node (due to input/output is not allocated)
+    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
+    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
+      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
+    }
+    auto num_outputs = node.source->num_outputs();
+    for (size_t i = 0; i < num_outputs && !stop; ++i) {
+      if (excludes.count(idx.entry_id(nid, i))) stop = true;
+    }
+
+    // Create opr segment for previous nodes.
+    if (stop && nid > seg_start) {
+      auto& seg = (*opr_segs)[seg_start];
+      if (seg_execs.size()) {
+        seg = EngineOprSeg{false, nid};
+        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
       } else {
-        *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
-        if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
-          array_reqs->at(i) = kWriteInplace;
-        }
+        seg = EngineOprSeg{true, nid, nullptr};
       }
+      seg_start = nid;
+      seg_execs.clear();
+    }
+
+    seg_execs.push_back(exec);
+
+    auto& seg = (*opr_segs)[nid];
+    if (is_async) {
+      seg = EngineOprSeg{false, nid + 1};
+      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
+      seg_execs.clear();
+      seg_start = nid + 1;
+    } else if (!valid) {
+      seg = EngineOprSeg{false, nid + 1, nullptr};
+      seg_execs.clear();
+      seg_start = nid + 1;
+    }
+  }
+  // The last segment
+  if (end_nid > seg_start) {
+    auto& seg = (*opr_segs)[seg_start];
+    if (seg_execs.size()) {
+      seg = EngineOprSeg{false, end_nid};
+      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
     } else {
-      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
-                           shapes[i], default_ctx, true, dtypes[i]);
+      seg = EngineOprSeg{true, end_nid, nullptr};
     }
   }
 }
 
+
+void RunGraph(const bool retain_graph,
+              const nnvm::IndexedGraph& idx,
+              const std::vector<NDArray*> arrays,
+              size_t node_start, size_t node_end,
+              std::vector<OpReqType>&& array_reqs,
+              std::vector<uint32_t>&& ref_count,
+              std::vector<OpStatePtr> *p_states,
+              const DispatchModeVector &dispatch_modes);
+
 }  // namespace imperative
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index e540657..6fafb36 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,6 +22,7 @@ from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import setup_module, with_seed, assertRaises, teardown
 import numpy as np
+from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
 from copy import deepcopy
 import warnings
@@ -1132,7 +1133,6 @@ def test_hybrid_multi_context():
     net.hybridize()
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
-
 @with_seed()
 def test_zero_grad():
     data = mx.nd.random.uniform(shape=(3,3))
@@ -1145,6 +1145,60 @@ def test_zero_grad():
     grad = net.collect_params()['test_zero_grad_weight'].grad()
     assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
 
+def check_hybrid_static_memory(**kwargs):
+    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+    x.attach_grad()
+
+    net1 = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
+    net2 = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
+    net2.hybridize(**kwargs)
+    net1(x)
+    net2(x)
+
+    def test(net, x):
+        with mx.autograd.record():
+            y = net(x) + net(x)
+            y.backward()
+
+    grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}
+
+        return y, grads
+
+    y1, grads1 = test(net1, x)
+    y2, grads2 = test(net2, x)
+
+    assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
+    for key in grads1:
+    assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
+
+def test_hybrid_static_memory():
+    check_hybrid_static_memory()
+    check_hybrid_static_memory(static_alloc=True)
+    check_hybrid_static_memory(static_alloc=True, static_shape=True)
+
+def check_hybrid_static_memory_switching(**kwargs):
+    net = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, ctx=mx.context.current_context())
+    net.hybridize(**kwargs)
+
+    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
+    net(x)
+    with mx.autograd.record():
+        y = net(x)
+        y.backward()
+    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+    net(x)
+    with mx.autograd.record():
+        y = net(x)
+        y.backward()
+    mx.nd.waitall()
+
+def test_hybrid_static_memory_switching():
+    check_hybrid_static_memory_switching()
+    check_hybrid_static_memory_switching(static_alloc=True)
+    check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
 
 @with_seed()
 def test_hook():
@@ -1239,6 +1293,17 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
+def test_hybrid_static_memory_recording():
+    net = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, ctx=mx.context.current_context())
+    net.hybridize(static_alloc=True)
+
+    x = mx.nd.random.uniform(shape=(1, 3, 32, 32))
+    with mx.autograd.record(True):
+        net(x)
+    net(x)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to