reminisce closed pull request #12104: [DO NOT REVIEW] Subgraph API
URL: https://github.com/apache/incubator-mxnet/pull/12104
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub will not display
its diff after the merge, so the full diff is reproduced below:

diff --git a/example/subgraph_op/imagenet_inference.py 
b/example/subgraph_op/imagenet_inference.py
index 8a38cffc919..a0f16f67408 100644
--- a/example/subgraph_op/imagenet_inference.py
+++ b/example/subgraph_op/imagenet_inference.py
@@ -87,7 +87,8 @@ def score(sym, arg_params, aux_params, data, devs, 
label_name, max_num_examples,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Score a model on a dataset')
-    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 
'imagenet1k-inception-bn'],
+    parser.add_argument('--model', type=str, required=True,
+                        choices=['imagenet1k-resnet-152', 
'imagenet1k-inception-bn'],
                         help='currently only supports imagenet1k-resnet-152 or 
imagenet1k-inception-bn')
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--label-name', type=str, default='softmax_label')
@@ -107,6 +108,8 @@ def score(sym, arg_params, aux_params, data, devs, 
label_name, max_num_examples,
                         help='shuffling seed, see'
                              ' 
https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                              ' for more details')
+    parser.add_argument('--subgraph-backend', type=str, default='default', 
help='subgraph backend name.')
+    parser.add_argument('--ctx', type=str, default='cpu')
 
     args = parser.parse_args()
 
@@ -133,6 +136,15 @@ def score(sym, arg_params, aux_params, data, devs, 
label_name, max_num_examples,
     download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
     logger.info('Dataset for inference: %s' % dataset)
 
+    subgraph_backend = args.subgraph_backend
+
+    if args.ctx == 'cpu':
+        ctx = mx.cpu()
+    elif args.ctx == 'gpu':
+        ctx = mx.gpu(0)
+    else:
+        raise ValueError('unknown ctx option, only cpu and gpu are supported')
+
     # creating data iterator
     data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                  label_width=1,
@@ -151,16 +163,21 @@ def score(sym, arg_params, aux_params, data, devs, 
label_name, max_num_examples,
     prefix, epoch = download_model(model_name=args.model, logger=logger)
     sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
     op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
-    out = SymbolHandle()
-    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), 
c_str_array(op_names),
-                                     ctypes.byref(out)))
-    psym = Symbol(out)
-
+    if subgraph_backend is not None:
+        os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+        if subgraph_backend == 'default':
+            
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+                                                         
c_str_array(op_names)))
     # make sure that fp32 inference works on the same images as calibrated 
quantized model
     logger.info('Skipping the first %d batches' % args.num_skipped_batches)
     data = advance_data_iter(data, args.num_skipped_batches)
 
     num_inference_images = args.num_inference_batches * batch_size
     logger.info('Running model %s for inference' % args.model)
-    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+    score(sym, arg_params, aux_params, data, [ctx], label_name,
           max_num_examples=num_inference_images, logger=logger)
+
+    if subgraph_backend is not None:
+        del os.environ['MXNET_SUBGRAPH_BACKEND']
+        if subgraph_backend == 'default':
+            
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2987cd75435..75147cfd706 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1479,11 +1479,6 @@ MXNET_DLL int 
MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                                const float* high_quantiles,
                                                SymbolHandle* ret_sym_handle);
 
-MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
-                               const mx_uint num_ops,
-                               const char** op_names,
-                               SymbolHandle* ret_sym_handle);
-
 //--------------------------------------------
 // Part 4: Executor interface
 //--------------------------------------------
diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
new file mode 100644
index 00000000000..fe6fc7fe9cc
--- /dev/null
+++ b/include/mxnet/c_api_test.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.h
+ * \brief C API of mxnet for ease of testing backend in Python
+ */
+#ifndef MXNET_C_API_TEST_H_
+#define MXNET_C_API_TEST_H_
+
+/*! \brief Inhibit C++ name-mangling for MXNet functions. */
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+#include <mxnet/c_api.h>
+
+/*!
+ * \brief This API partitions a graph only by the operator names
+ * provided by users. This will attach a DefaultSubgraphProperty
+ * to the input graph for partitioning. This function should be
+ * used only for the testing purpose.
+ */
+MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                                        const char* prop_name,
+                                        const mx_uint num_ops,
+                                        const char** op_names,
+                                        SymbolHandle* ret_sym_handle);
+
+/*!
+ * \brief Given a subgraph property name, use the provided op names
+ * as the op_names attribute for that subgraph property, instead of
+ * the predefined one. This is only for the purpose of testing.
+ */
+MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                           const mx_uint num_ops,
+                                           const char** op_names);
+
+/*!
+ * \brief Given a subgraph property name, delete the op name set
+ * in the SubgraphPropertyOpNameSet.
+ */
+MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // MXNET_C_API_TEST_H_
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 2c33b6cd16b..11e64edfcd5 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -43,7 +43,7 @@ class Engine;
 namespace engine {
 /*! \brief base class of engine variables.*/
 struct Var {
-  virtual uint32_t version() {
+  virtual size_t version() {
     return version_;
   }
   virtual ~Var() = default;
@@ -58,7 +58,7 @@ struct Var {
    * \brief version number of the var. Every time the object it is associated 
with
    * is modified, the version number is incremented by 1.
    */
-  uint32_t version_{0};
+  size_t version_{0};
 };  // struct Var
 
 /*! \brief Internal representation of operator.  */
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 7ed86ec7888..c27a59a67c6 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,7 +31,6 @@
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
 #include "../executor/exec_pass.h"
-#include "../operator/subgraph/default_subgraph_op.h"
 
 namespace mxnet {
 namespace op {
@@ -697,27 +696,3 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle 
qsym_handle,
   *ret_qsym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
-
-int MXPartitionGraph(SymbolHandle sym_handle,
-                     const mx_uint num_ops,
-                     const char** op_names,
-                     SymbolHandle* ret_sym_handle) {
-  nnvm::Symbol* s = new nnvm::Symbol();
-  API_BEGIN();
-  std::unordered_set<std::string> op_name_set;
-  for (size_t i = 0; i < num_ops; ++i) {
-    op_name_set.emplace(op_names[i]);
-  }
-  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
-  *s = sym->Copy();
-  nnvm::Graph g = Symbol2Graph(*s);
-  if (!op_name_set.empty()) {
-    mxnet::op::SubgraphPropertyPtr property
-        = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
-    g.attrs["subgraph_property"] = 
std::make_shared<nnvm::any>(std::move(property));
-  }
-  g = ApplyPass(std::move(g), "PartitionGraph");
-  s->outputs = g.outputs;
-  *ret_sym_handle = s;
-  API_END_HANDLE_ERROR(delete s);
-}
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
new file mode 100644
index 00000000000..2f5ad7611c4
--- /dev/null
+++ b/src/c_api/c_api_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.cc
+ * \brief C API of mxnet for the ease of testing backend in Python
+ */
+#include <mxnet/c_api_test.h>
+#include <nnvm/pass.h>
+#include "./c_api_common.h"
+#include "../operator/subgraph/default_subgraph_property.h"
+
+int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                              const char* prop_name,
+                              const mx_uint num_ops,
+                              const char** op_names,
+                              SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = sym->Copy();
+  nnvm::Graph g;
+  g.outputs = s->outputs;
+  if (!op_name_set.empty()) {
+    mxnet::op::SubgraphPropertyPtr property
+        = 
mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+    property->SetAttr("op_names", op_name_set);
+    g.attrs["subgraph_property"] = 
std::make_shared<nnvm::any>(std::move(property));
+  }
+  g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+  s->outputs = g.outputs;
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
+
+int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                 const mx_uint num_ops,
+                                 const char** op_names) {
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  (*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
+  API_END();
+}
+
+int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
+  API_BEGIN();
+  mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
+  API_END();
+}
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index e0a47fa9951..8adac9e30ff 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -86,10 +86,6 @@ class NaiveEngine final : public Engine {
   // new variables
   VarHandle NewVariable() override {
     return NaiveVar::New();
-#if 0
-    size_t v = ++counter_;
-    return reinterpret_cast<VarHandle>(v);
-#endif
   }
 
   OprHandle NewOperator(AsyncFn fn,
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index bd1169768eb..3a7587fef13 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -199,7 +199,7 @@ inline bool ThreadedVar::ready_to_read() {
   return this->is_ready_to_read();
 }
 
-inline uint32_t ThreadedVar::version() {
+inline size_t ThreadedVar::version() {
   std::lock_guard<std::mutex> lock{mutex_};
   return this->version_;
 }
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 7730c064b2b..a2c1a2b943a 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,7 +162,7 @@ class ThreadedVar final
   inline void SetToDelete();
   /*! \return whether this variable is ready to read. */
   inline bool ready_to_read();
-  inline uint32_t version() override;
+  inline size_t version() override;
   /*!
    * \brief Cast a Var pointer to ThreadedVar pointer
    * \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 33c6f574a04..4fc36b9d326 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
 #include "../profiler/profiler.h"
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
+#include "../operator/subgraph/subgraph_property.h"
 
 namespace mxnet {
 namespace exec {
@@ -40,6 +41,7 @@ namespace exec {
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
+  subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -1699,6 +1701,146 @@ GraphExecutor::CachedSegOpr 
GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& 
ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& 
aux_state_ctxes) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), 
"__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal 
to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector& arg_stypes,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& 
aux_state_ctxes) {
+  auto subgraph_prop = 
op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  nnvm::Symbol ret = src.Copy();
+  nnvm::Graph g;
+  g.outputs = ret.outputs;
+  g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+                        ctx_map, in_arg_ctxes, aux_state_ctxes);
+  subgraph_prop->SetAttr("graph", g);
+  auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+  // assign a op name set to the subgraph property if it has been provided by 
users
+  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << 
prop_name
+              << " has been assigned a value. Please make sure it is 
initialized"
+                 " only for the testing purpose.";
+    subgraph_prop->SetAttr("op_names", it->second);
+  }
+  g.attrs["subgraph_property"] = 
std::make_shared<nnvm::any>(std::move(subgraph_prop));
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  ret.outputs = g.outputs;
+  return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to 
prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::unordered_map<std::string, 
TShape>& arg_shape_map,
+                                   const std::unordered_map<std::string, int>& 
arg_dtype_map,
+                                   const std::unordered_map<std::string, int>& 
arg_stype_map,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& 
aux_state_ctxes) {
+  const std::vector<std::string> input_names = 
src.ListInputNames(Symbol::kAll);
+  nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+  nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+  StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto it1 = arg_shape_map.find(input_names[i]);
+    if (arg_shape_map.end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map.find(input_names[i]);
+    if (arg_dtype_map.end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map.find(input_names[i]);
+    if (arg_stype_map.end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to 
prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::vector<NDArray> &in_args,
+                                   const std::vector<NDArray> &aux_states,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map) {
+  const std::vector<std::string> input_names = 
src.ListInputNames(Symbol::kAll);
+  const std::vector<std::string> arg_names = 
src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> aux_names = 
src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(arg_names.size(), in_args.size());
+  CHECK_EQ(aux_names.size(), aux_states.size());
+  nnvm::ShapeVector arg_shapes;  // all input shapes
+  arg_shapes.reserve(input_names.size());
+  nnvm::DTypeVector arg_dtypes;  // all input dtypes
+  arg_dtypes.reserve(input_names.size());
+  StorageTypeVector arg_stypes;  // all input stypes
+  arg_stypes.reserve(input_names.size());
+  std::vector<Context> in_arg_ctxes(in_args.size());
+  std::vector<Context> aux_state_ctxes(aux_states.size());
+
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+      arg_shapes.push_back(aux_states[i2].shape());
+      arg_dtypes.push_back(aux_states[i2].dtype());
+      arg_stypes.push_back(aux_states[i2].storage_type());
+      aux_state_ctxes[i2] = aux_states[i2].ctx();
+      ++i2;
+    } else {
+      CHECK(i1 < arg_names.size());
+      CHECK_EQ(arg_names[i1], input_names[i]);
+      arg_shapes.push_back(in_args[i1].shape());
+      arg_dtypes.push_back(in_args[i1].dtype());
+      arg_stypes.push_back(in_args[i1].storage_type());
+      in_arg_ctxes[i1] = in_args[i1].ctx();
+      ++i1;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
 }  // namespace exec
 
 Executor *Executor::SimpleBind(nnvm::Symbol symbol,
@@ -1718,6 +1860,11 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                std::unordered_map<std::string, NDArray>* 
shared_buffer,
                                Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), 
arg_shape_map, arg_dtype_map,
+                                  arg_stype_map, default_ctx, group2ctx, 
in_arg_ctxes,
+                                  aux_state_ctxes);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
              arg_shape_map, arg_dtype_map, arg_stype_map,
@@ -1736,6 +1883,10 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
                          const std::vector<NDArray> &aux_states,
                          Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, 
aux_states,
+                                  default_ctx, group2ctx);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_args, arg_grad_store, grad_req_type, aux_states,
              reinterpret_cast<Executor*>(shared_exec));
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b4526..b4d36b14d5a 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -117,6 +117,8 @@ class GraphExecutor : public Executor {
                     std::vector<NDArray>* arg_grads,
                     std::vector<NDArray>* aux_states) override;
 
+  const std::string& subgraph_property() const { return subgraph_property_; }
+
  protected:
   friend class mxnet::Imperative;
   // Information about operational node
@@ -256,6 +258,8 @@ class GraphExecutor : public Executor {
   std::unordered_set<std::string> cached_seg_opr_names_;
   // verbose logging
   bool log_verbose_ = false;
+  // subgraph property name
+  std::string subgraph_property_;
 };
 
 }  // namespace exec
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 472312d0a46..bf461048c11 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -57,22 +57,22 @@ struct SimpleNode {
 }  // namespace sg
 
 inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListInputNames(nnvm::Symbol::kAll).size();
 }
 
 inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListOutputNames().size();
 }
 
 inline std::vector<std::string> DefaultSubgraphOpListInputs(const 
nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListInputNames(nnvm::Symbol::kAll);
 }
 
 inline std::vector<std::string> DefaultSubgraphOpListOutputs(const 
nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListOutputNames();
 }
 
@@ -80,7 +80,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& 
attrs,
                                    std::vector<TShape> *in_shapes,
                                    std::vector<TShape> *out_shapes) {
   using namespace exec;
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -124,7 +124,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& 
attrs,
 inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int> *in_types,
                                   std::vector<int> *out_types) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -169,7 +169,7 @@ inline bool DefaultSubgraphOpStorageType(const 
nnvm::NodeAttrs& attrs,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int>* in_stypes,
                                          std::vector<int>* out_stypes) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -222,7 +222,7 @@ inline ExecType DefaultSubgraphOpExecType(const 
nnvm::NodeAttrs& attrs) {
 }
 
 inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const 
nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   const std::vector<std::string> input_names = 
subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
   const std::vector<std::string> immutable_input_names =
     subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -245,7 +245,7 @@ inline std::vector<uint32_t> 
DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
 }
 
 inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const 
nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
   std::set<ResourceRequest::Type> resource_types;
   DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
diff --git a/src/operator/subgraph/default_subgraph_op.cc 
b/src/operator/subgraph/default_subgraph_op.cc
index 8372ae9326d..491d6ee9960 100644
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -18,7 +18,7 @@
 */
 
 #include <mxnet/ndarray.h>
-#include "./default_subgraph_op.h"
+#include "./common.h"
 #include "../../imperative/imperative_utils.h"
 #include "../../imperative/cached_op.h"
 
@@ -30,7 +30,8 @@ namespace op {
 class DefaultSubgraphOperator {
  public:
   explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
-    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}}));
+    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
+                                            {"static_shape", "true"}}));
   }
 
   void Forward(const OpContext& ctx,
@@ -79,8 +80,7 @@ OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& 
attrs,
                                         Context ctx,
                                         const std::vector<TShape>& in_shapes,
                                         const std::vector<int>& in_types) {
-  const Symbol& subgraph_sym = nnvm::get<Symbol>(attrs.parsed);
-  return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
+  return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
 }
 
 void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
diff --git a/src/operator/subgraph/default_subgraph_op.cu 
b/src/operator/subgraph/default_subgraph_op.cu
index 15a76e3bbb0..008826b21d7 100644
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -19,11 +19,14 @@
 
 /*!
  *  Copyright (c) 2018 by Contributors
- * \file subgraph_op.cu
+ * \file default_subgraph_op.cu
  * \brief GPU Implementation of subgraph operations
  */
 
-#include "./default_subgraph_op.h"
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/subgraph/default_subgraph_op.h 
b/src/operator/subgraph/default_subgraph_op.h
deleted file mode 100644
index 7d6624ef14d..00000000000
--- a/src/operator/subgraph/default_subgraph_op.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-
-#include <vector>
-#include <string>
-#include "./common.h"
-
-namespace mxnet {
-namespace op {
-
-/*
- * This provides criteria for selecting nodes in a subgraph.
- * When a node is passed to this object, the selection criteria may be changed.
- * We can also specify what links we should use when traversing the neighbor
- * nodes.
- */
-class SubgraphSelector {
- public:
-  virtual ~SubgraphSelector() {
-  }
-  // Determine if the node should be selected for a subgraph.
-  virtual bool Select(const nnvm::Node &n) = 0;
-  // Determine if the input node should be selected for a subgraph.
-  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 
0;
-  // Determine if the output node should be selected for a subgraph.
-  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 
0;
-  // Post processes pre-selected subgraph nodes. Return a list of nodes that
-  // users want to keep in subgraph(s).
-  virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
-                                          const std::vector<nnvm::Node*>& 
candidates) {
-    return candidates;
-  }
-};
-
-using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
-
-/*!
- * \brief This provides a set of properties for partitioning a graph into 
subgraphs,
- * reconstructing a new graph from the subgraphs and creating a subgraph
- * operator to execute the subgraph.
- */
-class SubgraphProperty {
- public:
-  // the criteria of selecting the subgraph nodes.
-  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
-  // create an nnvm node for a given subgraph. Here users can customize how to
-  // execute the operators in the subgraph.
-  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
-                                           const int subgraph_id = 0) const = 
0;
-};
-
-using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
-
-void RegisterSubgraphProperty(SubgraphPropertyPtr property);
-
-/*
- * This selects nodes for a subgraph that only contains operators
- * in a given set and it visits nodes via both input and output links.
- */
-class ContainOpSelector: public SubgraphSelector {
-  std::shared_ptr<const std::unordered_set<std::string>> op_names;
-
- public:
-  explicit ContainOpSelector(std::shared_ptr<const 
std::unordered_set<std::string>> op_names) {
-    this->op_names = op_names;
-  }
-
-  virtual bool Select(const nnvm::Node &n) {
-    return !n.is_variable() && op_names->count(n.op()->name);
-  }
-
-  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
-    return !new_node.is_variable() && op_names->count(new_node.op()->name);
-  }
-
-  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
-    return !new_node.is_variable() && op_names->count(new_node.op()->name);
-  }
-};
-
-/*
- * This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by 
_default_subgraph_op.
- */
-class DefaultSubgraphProperty: public SubgraphProperty {
- public:
-  explicit DefaultSubgraphProperty(const std::unordered_set<std::string> 
&op_names) :
-    op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
-  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
-                                           const int subgraph_id = 0) const {
-    nnvm::NodePtr n = nnvm::Node::Create();
-    n->attrs.op = Op::Get("_default_subgraph_op");
-    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
-    n->attrs.parsed = sym;
-    return n;
-  }
-  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
-    return std::make_shared<ContainOpSelector>(op_names_);
-  }
-
- private:
-  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
-};
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
diff --git a/src/operator/subgraph/default_subgraph_property.h 
b/src/operator/subgraph/default_subgraph_property.h
new file mode 100644
index 00000000000..3882247dcd6
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_property.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+#include "./subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This selects nodes for a subgraph that only contains operators
+ * in a given set and it visits nodes via both input and output links.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+  explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
+    : op_names_(op_names) {}
+
+  virtual bool Select(const nnvm::Node &n) {
+    return !n.is_variable() && op_names_.count(n.op()->name);
+  }
+
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+  }
+
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+  }
+ private:
+  const std::unordered_set<std::string>& op_names_;
+};
+
+/*
+ * This subgraph property finds a subgraph whose nodes have only operators
+ * within a set. The operators in the subgraph will be executed by 
_default_subgraph_op.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  static SubgraphPropertyPtr Create() { return 
std::make_shared<DefaultSubgraphProperty>(); }
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                           const int subgraph_id = 0) const {
+    nnvm::NodePtr n = nnvm::Node::Create();
+    n->attrs.op = Op::Get("_default_subgraph_op");
+    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+    return n;
+  }
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+    return std::make_shared<ContainOpSelector>(
+        this->GetAttr<std::unordered_set<std::string>>("op_names"));
+  }
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
diff --git a/src/operator/subgraph/partition_graph.cc 
b/src/operator/subgraph/partition_graph.cc
index 9672877eb1d..e8c3069255c 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -29,7 +29,7 @@
 #include <stack>
 #include <queue>
 
-#include "./default_subgraph_op.h"
+#include "./subgraph_property.h"
 #include "./common.h"
 
 namespace nnvm {
@@ -408,7 +408,7 @@ void FindSubgraphs(Graph* g,
                              &preselected_nodes);
 
       // filter out unqualified pre-selected nodes
-      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, 
preselected_nodes);
+      std::vector<nnvm::Node*> filtered_nodes = 
subgraph_selector->Filter(preselected_nodes);
 
       // make sure filtered_nodes is a subset of preselected_nodes
       for (const auto n : filtered_nodes) {
diff --git a/src/operator/subgraph/subgraph_property.h 
b/src/operator/subgraph/subgraph_property.h
new file mode 100644
index 00000000000..2153a366471
--- /dev/null
+++ b/src/operator/subgraph/subgraph_property.h
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+
+#include <nnvm/node.h>
+#include <dmlc/base.h>
+#include <dmlc/thread_local.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This provides criteria for selecting nodes in a subgraph.
+ * When a node is passed to this object, the selection criteria may be changed.
+ * We can also specify what links we should use when traversing the neighbor
+ * nodes.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() {}
+  // Determine if the node should be selected for a subgraph.
+  virtual bool Select(const nnvm::Node &n) = 0;
+  // Determine if the input node should be selected for a subgraph.
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 
0;
+  // Determine if the output node should be selected for a subgraph.
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 
0;
+  // Post processes pre-selected subgraph nodes. Return a list of nodes that
+  // users want to keep in subgraph(s).
+  virtual std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& 
candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into 
subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // the criteria of selecting the subgraph nodes.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  // create an nnvm node for a given subgraph. Here users can customize how to
+  // execute the operators in the subgraph.
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 
0;
+  // set an attr with name in the attr map
+  template<typename T>
+  SubgraphProperty& SetAttr(const std::string& name, const T& value) {
+    attrs_[name] = std::make_shared<dmlc::any>(value);
+    return *this;
+  }
+  // get the attr with the name
+  template<typename T>
+  const T& GetAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in 
SubgraphProperty";
+    return nnvm::get<T>(*it->second);
+  }
+ protected:
+  std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+class SubgraphPropertyRegistry {
+ public:
+  typedef SubgraphPropertyPtr (*SubgraphPropertyCreateFn)(void);
+  static SubgraphPropertyRegistry* Get() {
+    static SubgraphPropertyRegistry inst;
+    return &inst;
+  }
+
+  SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) {
+    auto it = prop_fn_map_.find(name);
+    CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name
+                                    << " is not found in 
SubgraphPropertyRegistry";
+    return it->second();
+  }
+
+  SubgraphPropertyCreateFn __REGISTER__(const std::string& name, 
SubgraphPropertyCreateFn fn) {
+    CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name
+                                           << " has been registered";
+    prop_fn_map_[name] = fn;
+    return prop_fn_map_[name];
+  }
+
+ private:
+  SubgraphPropertyRegistry() = default;
+  SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete;
+  SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete;
+  SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = 
delete;
+  std::unordered_map<std::string, SubgraphPropertyCreateFn> prop_fn_map_;
+};
+
+// This op name set is for setting the names of operators that should be 
grouped into
+// subgraphs. In practice, every backend accelerator should have a predefined 
name set.
+// This set is only used for the testing purpose.
+// key: property name, value: op name set
+typedef dmlc::ThreadLocalStore<std::unordered_map<std::string, 
std::unordered_set<std::string>>>
+  SubgraphPropertyOpNameSet;
+
+#define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \
+  static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ## 
Name ## __ = \
+    SubgraphPropertyRegistry::Get()->__REGISTER__(#Name, 
&SubgraphPropertyType::Create)
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
diff --git a/tests/cpp/engine/threaded_engine_test.cc 
b/tests/cpp/engine/threaded_engine_test.cc
index 92d0958c463..6d669c19bca 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -275,6 +275,64 @@ TEST(Engine, basics) {
   LOG(INFO) << "All pass";
 }
 
+TEST(Engine, VarVersion) {
+  const size_t num_engines = 3;
+  std::vector<mxnet::Engine*> engines(num_engines);
+  engines[0] = mxnet::engine::CreateNaiveEngine();
+  engines[1] = mxnet::engine::CreateThreadedEnginePooled();
+  engines[2] = mxnet::engine::CreateThreadedEnginePerDevice();
+  std::string type_names[3] = {"NaiveEngine", "ThreadedEnginePooled", 
"ThreadedEnginePerDevice"};
+  for (size_t k = 0; k < num_engines; ++k) {
+    auto engine = engines[k];
+    std::vector<mxnet::Engine::OprHandle> oprs;
+
+    LOG(INFO) << "Testing var as a read dependency in " << type_names[k];
+    auto var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {var}, {}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 0U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    LOG(INFO) << "Testing var as a write dependency in " << type_names[k];
+    var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    oprs.clear();
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {}, {var}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 10U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    var = nullptr;
+    oprs.clear();
+    LOG(INFO) << "All pass";
+  }
+}
+
 #ifdef _OPENMP
 
 struct TestSaveAndRestoreOMPState {
diff --git a/tests/python/unittest/test_gluon_trainer.py 
b/tests/python/unittest/test_gluon_trainer.py
index 13e8e4e4b81..2a34400d60a 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,7 +175,6 @@ def test_trainer_save_load():
     # check if parameter dict is correctly associated with optimizer after 
load_state
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
-@unittest.skip("temporarily disabled till it gets fixed. tracked at 
https://github.com/apache/incubator-mxnet/issues/11353")
 @with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py 
b/tests/python/unittest/test_subgraph_op.py
index f6a33c244a7..40d609ad354 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -15,26 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
 import ctypes
 import mxnet as mx
-from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, 
c_str
 from mxnet.symbol import Symbol
 import numpy as np
 from mxnet.test_utils import assert_almost_equal
 
 
 def test_subgraph_exe():
-    def check_subgraph_exe(sym, op_names):
+    def _check_subgraph_exe1(sym, op_names):
+        """Use the partitioned sym to simple_bind an executor and compare the 
outputs
+        with those of the original executor"""
         out = SymbolHandle()
-        check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
-                                         c_str_array(op_names), 
ctypes.byref(out)))
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, 
c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), 
ctypes.byref(out)))
 
         partitioned_sym = Symbol(out)
         assert partitioned_sym.list_inputs() == sym.list_inputs()
         assert partitioned_sym.list_arguments() == sym.list_arguments()
         assert partitioned_sym.list_auxiliary_states() == 
sym.list_auxiliary_states()
-        exe = sym.simple_bind(ctx=mx.cpu(), grad_req='null')
-        partitioned_exe = partitioned_sym.simple_bind(ctx=mx.cpu(), 
grad_req='null')
+        exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        partitioned_exe = 
partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
         input_names = sym.list_inputs()
         for name in input_names:
             if name in exe.arg_dict:
@@ -46,12 +49,109 @@ def check_subgraph_exe(sym, op_names):
                 partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
         exe.forward()
         partitioned_exe.forward()
-        mx.nd.waitall()
         assert len(exe.outputs) == len(partitioned_exe.outputs)
         for i in range(len(exe.outputs)):
             assert_almost_equal((exe.outputs[i] - 
partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                                 np.zeros(shape=(1,)))
 
+    def _check_subgraph_exe2(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph 
partitioning in simple_bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, 
original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+                                                             
c_str_array(op_names)))
+            exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+            input_names = sym.list_inputs()
+            for name in input_names:
+                if name in exe.arg_dict:
+                    exe.arg_dict[name][:] = 
mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
+                        if original_exec is None else 
original_exec.arg_dict[name]
+                else:
+                    assert name in exe.aux_dict
+                    exe.aux_dict[name][:] = 
mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
+                        if original_exec is None else 
original_exec.aux_dict[name]
+            exe.forward()
+            if subgraph_backend is not None:
+                
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, 
original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - 
outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe3(sym, op_names):
+        """Use the partitioned sym to bind an executor and compare the outputs
+        with those of the original executor"""
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, 
c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), 
ctypes.byref(out)))
+
+        partitioned_sym = Symbol(out)
+        input_names = sym.list_inputs()
+        arg_names = sym.list_arguments()
+        aux_names = sym.list_auxiliary_states()
+        assert partitioned_sym.list_inputs() == input_names
+        assert partitioned_sym.list_arguments() == arg_names
+        assert partitioned_sym.list_auxiliary_states() == aux_names
+        arg_shapes, _, aux_shapes = sym.infer_shape()
+        arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+        aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+        exe = sym.bind(ctx=mx.current_context(), args=arg_array, 
aux_states=aux_array, grad_req='null')
+        partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), 
args=arg_array,
+                                               aux_states=aux_array, 
grad_req='null')
+        exe.forward()
+        partitioned_exe.forward()
+        assert len(exe.outputs) == len(partitioned_exe.outputs)
+        for i in range(len(exe.outputs)):
+            assert_almost_equal((exe.outputs[i] - 
partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+                                np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe4(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph 
partitioning in bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, 
original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+                                                             
c_str_array(op_names)))
+            arg_shapes, _, aux_shapes = sym.infer_shape()
+            if subgraph_backend is None:
+                arg_array = [mx.nd.random.uniform(shape=shape) for shape in 
arg_shapes]
+                aux_array = [mx.nd.random.uniform(shape=shape) for shape in 
aux_shapes]
+            else:
+                arg_array = None
+                aux_array = None
+            exe = sym.bind(ctx=mx.current_context(),
+                           args=arg_array if subgraph_backend is None else 
original_exec.arg_arrays,
+                           aux_states=aux_array if subgraph_backend is None 
else original_exec.aux_arrays,
+                           grad_req='null')
+            exe.forward()
+            if subgraph_backend is not None:
+                
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, 
original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - 
outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def check_subgraph_exe(sym, op_names):
+        _check_subgraph_exe1(sym, op_names)
+        _check_subgraph_exe2(sym, op_names)
+        _check_subgraph_exe3(sym, op_names)
+        _check_subgraph_exe4(sym, op_names)
+
     def test_network_structure_1():
         data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
         data2 = mx.sym.var('data2')


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to