This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a64cf7d Subgraph API for integrating accelerators with MXNet (#12157)
a64cf7d is described below
commit a64cf7d9c8c1c473e201b5bd68ab9af6bf7365ba
Author: reminisce <[email protected]>
AuthorDate: Thu Aug 30 19:13:33 2018 -0700
Subgraph API for integrating accelerators with MXNet (#12157)
* Graph partitioner and subgraph op (#11251)
Graph partitioner and subgraph op
Fix duplicate entry bugs (#11767)
Make subgraph var node name unique (#11876)
[DO NOT REVIEW] Fix bug of eliminating cycles (#11907)
* Fix cycle bug
* Fix decycle bug
* Fix comment
[DO NOT REVIEW] Subgraph API (#12104)
* Initial commit
* Add unit tests
* Fix lint
* Fix lint
* Clean up
* Add graph partitioning to Bind
* Add property name to graph partitioning c api
* Fix unit test gpu context
* Address cr
* Move subgraph to attrs.subgraphs and fix the example
* Fix lint
* Add var version unit test
* Address cr
* Enable unit test that was flaky
* Clean up
* Clean up
* Clean up
* Change version return type in NDArray
* Clean up
* Add register or get for subgraph prop registry
* Address cr
* Remove unnecessary code
* Handle var version issue in naive engine
* Delete example
* Remove registration of resource request for default subgraph op
* Add doc string
* Improve doc string
---
include/mxnet/c_api_test.h | 66 ++
include/mxnet/engine.h | 22 +-
include/mxnet/ndarray.h | 4 +
src/c_api/c_api_test.cc | 73 ++
src/engine/engine_impl.h | 14 -
src/engine/naive_engine.cc | 31 +-
src/engine/threaded_engine.cc | 10 +-
src/engine/threaded_engine.h | 1 +
src/executor/graph_executor.cc | 151 ++++
src/executor/graph_executor.h | 4 +
src/operator/subgraph/common.h | 237 +++++++
src/operator/subgraph/default_subgraph_op.cc | 112 +++
src/operator/subgraph/default_subgraph_op.cu | 44 ++
src/operator/subgraph/default_subgraph_property.cc | 76 ++
src/operator/subgraph/partition_graph.cc | 774 +++++++++++++++++++++
src/operator/subgraph/subgraph_property.h | 166 +++++
tests/cpp/engine/threaded_engine_test.cc | 58 ++
tests/python/gpu/test_operator_gpu.py | 1 +
tests/python/unittest/test_subgraph_op.py | 238 +++++++
19 files changed, 2059 insertions(+), 23 deletions(-)
diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
new file mode 100644
index 0000000..fe6fc7f
--- /dev/null
+++ b/include/mxnet/c_api_test.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_api_test.h
+ * \brief C API of mxnet for ease of testing backend in Python
+ */
+#ifndef MXNET_C_API_TEST_H_
+#define MXNET_C_API_TEST_H_
+
+/*! \brief Inhibit C++ name-mangling for MXNet functions. */
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+#include <mxnet/c_api.h>
+
+/*!
+ * \brief This API partitions a graph only by the operator names
+ * provided by users. This will attach a DefaultSubgraphProperty
+ * to the input graph for partitioning. This function should be
+ * used only for the testing purpose.
+ */
+MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+ const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle);
+
+/*!
+ * \brief Given a subgraph property name, use the provided op names
+ * as the op_names attribute for that subgraph property, instead of
+ * the predefined one. This is only for the purpose of testing.
+ */
+MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names);
+
+/*!
+ * \brief Given a subgraph property name, delete the op name set
+ * in the SubgraphPropertyOpNameSet.
+ */
+MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
+
+#endif // MXNET_C_API_TEST_H_
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index dc48bfb..11e64ed 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;
/*! \brief namespace of engine internal types. */
namespace engine {
-/*! \brief Internal representation of variable. */
-struct Var;
+/*! \brief base class of engine variables.*/
+struct Var {
+ virtual size_t version() {
+ return version_;
+ }
+ virtual ~Var() = default;
+ /*!
+ * \brief cast variable to derived type T
+ * \tparam T the type we want to cast into.
+ * \return A casted variable.
+ */
+ template <typename T>
+ inline T* Cast();
+ /*!
+ * \brief version number of the var. Every time the object it is associated with
+ * is modified, the version number is incremented by 1.
+ */
+ size_t version_{0};
+}; // struct Var
+
/*! \brief Internal representation of operator. */
struct Opr;
/*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index bae3ea9..6141a4d 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -340,6 +340,10 @@ class NDArray {
inline size_t byte_offset() const {
return byte_offset_;
}
+ /*! \brief return var version of the NDArray*/
+ inline size_t version() const {
+ return var()->version();
+ }
/*!
* \brief save the content into binary stream
* \param strm the output stream
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
new file mode 100644
index 0000000..623faa7
--- /dev/null
+++ b/src/c_api/c_api_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_api_test.cc
+ * \brief C API of mxnet for the ease of testing backend in Python
+ */
+#include <mxnet/c_api_test.h>
+#include <nnvm/pass.h>
+#include "./c_api_common.h"
+#include "../operator/subgraph/subgraph_property.h"
+
+int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+ const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle) {
+ nnvm::Symbol* s = new nnvm::Symbol();
+ API_BEGIN();
+ std::unordered_set<std::string> op_name_set;
+ for (size_t i = 0; i < num_ops; ++i) {
+ op_name_set.emplace(op_names[i]);
+ }
+ nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+ *s = sym->Copy();
+ nnvm::Graph g;
+ g.outputs = s->outputs;
+ if (!op_name_set.empty()) {
+ mxnet::op::SubgraphPropertyPtr property
+ = mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+ property->SetAttr("op_names", op_name_set);
+ g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+ }
+ g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+ s->outputs = g.outputs;
+ *ret_sym_handle = s;
+ API_END_HANDLE_ERROR(delete s);
+}
+
+int MXSetSubgraphPropertyOpNames(const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names) {
+ API_BEGIN();
+ std::unordered_set<std::string> op_name_set;
+ for (size_t i = 0; i < num_ops; ++i) {
+ op_name_set.emplace(op_names[i]);
+ }
+ (*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
+ API_END();
+}
+
+int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
+ API_BEGIN();
+ mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
+ API_END();
+}
diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h
index b3ec34d..f15141f 100644
--- a/src/engine/engine_impl.h
+++ b/src/engine/engine_impl.h
@@ -33,20 +33,6 @@
namespace mxnet {
namespace engine {
-/*! \brief base class of engine variables, used for type checking */
-struct Var {
-#if ENGINE_DEBUG
- virtual ~Var() = default;
-#endif // ENGINE_DEBUG
- /*!
- * \brief cast variable to derived type T
- * \tparam T the type we want to cast into.
- * \return A casted variable.
- */
- template <typename T>
- inline T* Cast();
-}; // struct Var
-
/*! \brief base class of engine operators, used for type checking */
struct Opr {
#if ENGINE_DEBUG
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8196af2..daff530 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
+#include "../common/object_pool.h"
namespace mxnet {
namespace engine {
+/*!
+ * \brief var used in Naive Engine for tracking the version
+ * of the objects it is associated with.
+ */
+class NaiveVar final
+ : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
+ public:
+ inline static NaiveVar* CastFromBase(Var* ptr) {
+ return ptr->Cast<NaiveVar>();
+ }
+}; // class NaiveVar
+
+
// implement naive engine
class NaiveEngine final : public Engine {
public:
@@ -71,8 +85,7 @@ class NaiveEngine final : public Engine {
// new variables
VarHandle NewVariable() override {
- size_t v = ++counter_;
- return reinterpret_cast<VarHandle>(v);
+ return NaiveVar::New();
}
OprHandle NewOperator(AsyncFn fn,
@@ -146,6 +159,10 @@ class NaiveEngine final : public Engine {
opr->opr_profile.reset(new profiler::ProfileOperator(opr->opr_name, attrs.release()));
opr->opr_profile->start(exec_ctx.dev_type, exec_ctx.dev_id);
}
+ // increment mutable var version
+ for (auto var : mutable_vars) {
+ ++var->version_;
+ }
if (exec_ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
size_t dev_id = static_cast<size_t>(exec_ctx.dev_id);
@@ -171,8 +188,12 @@ class NaiveEngine final : public Engine {
}
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
- this->PushSync(delete_fn, exec_ctx, {}, {var},
- FnProperty::kNormal, 0, "DeleteVariable");
+ NaiveVar* naive_var = NaiveVar::CastFromBase(var);
+ this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+ delete_fn(ctx);
+ NaiveVar::Delete(naive_var);
+ on_complete();
+ }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
}
void WaitForVar(VarHandle var) override {
@@ -192,8 +213,6 @@ class NaiveEngine final : public Engine {
}
// whether action is completed
bool req_completed_;
- // counter
- std::atomic<size_t> counter_{0};
/*! \brief whether it is during shutdown phase*/
std::atomic<bool> shutdown_phase_{false};
// CPU stream
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..3a7587f 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
assert(pending_write_ != nullptr);
CHECK_EQ(num_pending_reads_, kWriteTriggered);
+ // increment version number
+ ++version_;
+
// really delete
if (to_delete_) {
VersionedVarBlock *head = pending_write_->next;
@@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
}
// This is outside of lock scope
// Be very carful, pending_write_ and num_pending_reads_
- // can change now, do not reply ont the two variables.
+ // can change now, do not rely on these two variables.
// The linked list \in [old_pending_write, end_of_read_chain)
// is already detached from this Var.
// So it is safe to modify these
@@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
return this->is_ready_to_read();
}
+inline size_t ThreadedVar::version() {
+ std::lock_guard<std::mutex> lock{mutex_};
+ return this->version_;
+}
+
// implementation of threaded engine
ThreadedVar* ThreadedEngine::NewVariable() {
return ThreadedVar::New(VersionedVarBlock::New());
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 428f0d8..a2c1a2b 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,6 +162,7 @@ class ThreadedVar final
inline void SetToDelete();
/*! \return whether this variable is ready to read. */
inline bool ready_to_read();
+ inline size_t version() override;
/*!
* \brief Cast a Var pointer to ThreadedVar pointer
* \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 32b14b8..265554a 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
#include "../profiler/profiler.h"
#include "../common/utils.h"
#include "../common/exec_utils.h"
+#include "../operator/subgraph/subgraph_property.h"
namespace mxnet {
namespace exec {
@@ -42,6 +43,7 @@ using namespace mxnet::common;
GraphExecutor::GraphExecutor() {
log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
need_grad_ = false;
+ subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
}
GraphExecutor::~GraphExecutor() {
@@ -1428,6 +1430,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
iter->c_str());
return ret;
}
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+ nnvm::ShapeVector arg_shapes,
+ nnvm::DTypeVector arg_dtypes,
+ StorageTypeVector arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ const auto& indexed_graph = g.indexed_graph();
+ const auto num_forward_inputs = indexed_graph.input_nodes().size();
+ g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+ aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+ g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+ if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+ HandleInferShapeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::ShapeVector>("shape"));
+ }
+ g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+ if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+ HandleInferTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::DTypeVector>("dtype"));
+ }
+ g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+ if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+ HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<StorageTypeVector>("storage_type"));
+ }
+ return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const nnvm::ShapeVector& arg_shapes,
+ const nnvm::DTypeVector& arg_dtypes,
+ const StorageTypeVector& arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+ nnvm::Symbol ret = src.Copy();
+ nnvm::Graph g;
+ g.outputs = ret.outputs;
+ g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+ ctx_map, in_arg_ctxes, aux_state_ctxes);
+ subgraph_prop->SetAttr("graph", g);
+ auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+ // assign a op name set to the subgraph property if it has been provided by users
+ if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+ LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name
+ << " has been assigned a value. Please make sure it is initialized"
+ " only for the testing purpose.";
+ subgraph_prop->SetAttr("op_names", it->second);
+ }
+ g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
+ g = ApplyPass(std::move(g), "PartitionGraph");
+ ret.outputs = g.outputs;
+ return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const std::unordered_map<std::string, TShape>& arg_shape_map,
+ const std::unordered_map<std::string, int>& arg_dtype_map,
+ const std::unordered_map<std::string, int>& arg_stype_map,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+ nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+ nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+ StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ auto it1 = arg_shape_map.find(input_names[i]);
+ if (arg_shape_map.end() != it1) {
+ arg_shapes[i] = it1->second;
+ }
+ auto it2 = arg_dtype_map.find(input_names[i]);
+ if (arg_dtype_map.end() != it2) {
+ arg_dtypes[i] = it2->second;
+ }
+ auto it3 = arg_stype_map.find(input_names[i]);
+ if (arg_stype_map.end() != it3) {
+ arg_stypes[i] = it3->second;
+ }
+ }
+ return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+ default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const std::vector<NDArray> &in_args,
+ const std::vector<NDArray> &aux_states,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map) {
+ const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+ const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+ const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+ CHECK_EQ(arg_names.size(), in_args.size());
+ CHECK_EQ(aux_names.size(), aux_states.size());
+ nnvm::ShapeVector arg_shapes; // all input shapes
+ arg_shapes.reserve(input_names.size());
+ nnvm::DTypeVector arg_dtypes; // all input dtypes
+ arg_dtypes.reserve(input_names.size());
+ StorageTypeVector arg_stypes; // all input stypes
+ arg_stypes.reserve(input_names.size());
+ std::vector<Context> in_arg_ctxes(in_args.size());
+ std::vector<Context> aux_state_ctxes(aux_states.size());
+
+ size_t i1 = 0, i2 = 0;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+ arg_shapes.push_back(aux_states[i2].shape());
+ arg_dtypes.push_back(aux_states[i2].dtype());
+ arg_stypes.push_back(aux_states[i2].storage_type());
+ aux_state_ctxes[i2] = aux_states[i2].ctx();
+ ++i2;
+ } else {
+ CHECK(i1 < arg_names.size());
+ CHECK_EQ(arg_names[i1], input_names[i]);
+ arg_shapes.push_back(in_args[i1].shape());
+ arg_dtypes.push_back(in_args[i1].dtype());
+ arg_stypes.push_back(in_args[i1].storage_type());
+ in_arg_ctxes[i1] = in_args[i1].ctx();
+ ++i1;
+ }
+ }
+ return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+ default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
} // namespace exec
Executor *Executor::SimpleBind(nnvm::Symbol symbol,
@@ -1447,6 +1589,11 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>* shared_buffer,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
+ if (!exec->subgraph_property().empty()) {
+ symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
+ arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
+ aux_state_ctxes);
+ }
exec->Init(symbol, default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
@@ -1465,6 +1612,10 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
const std::vector<NDArray> &aux_states,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
+ if (!exec->subgraph_property().empty()) {
+ symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states,
+ default_ctx, group2ctx);
+ }
exec->Init(symbol, default_ctx, group2ctx,
in_args, arg_grad_store, grad_req_type, aux_states,
reinterpret_cast<Executor*>(shared_exec));
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 7b936c3..b94bb43 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -117,6 +117,8 @@ class GraphExecutor : public Executor {
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states) override;
+ const std::string& subgraph_property() const { return subgraph_property_; }
+
protected:
friend class mxnet::Imperative;
// Information about operational node
@@ -256,6 +258,8 @@ class GraphExecutor : public Executor {
std::unordered_set<std::string> cached_seg_opr_names_;
// verbose logging
bool log_verbose_ = false;
+ // subgraph property name
+ std::string subgraph_property_;
};
} // namespace exec
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
new file mode 100644
index 0000000..22058d5
--- /dev/null
+++ b/src/operator/subgraph/common.h
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+
+#include <string>
+#include <set>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../../executor/exec_pass.h"
+
+namespace mxnet {
+namespace op {
+
+inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
+ return sym.ListInputNames(nnvm::Symbol::kAll).size();
+}
+
+inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
+ return sym.ListOutputNames().size();
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
+ return sym.ListInputNames(nnvm::Symbol::kAll);
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
+ return sym.ListOutputNames();
+}
+
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_shapes,
+ std::vector<TShape> *out_shapes) {
+ using namespace exec;
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+ CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+ // Put the input and output shapes to the shape vector.
+ nnvm::ShapeVector shapes(idx_g.num_node_entries());
+ const auto &input_nids = idx_g.input_nodes();
+ CHECK_EQ(input_nids.size(), in_shapes->size());
+ for (size_t i = 0; i < in_shapes->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ shapes[eid] = in_shapes->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_shapes->size());
+ for (size_t i = 0; i < out_shapes->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ shapes[eid] = out_shapes->at(i);
+ }
+
+ // Infer shape of the graph.
+ g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+ g = exec::InferShape(std::move(g));
+
+ // Copy the inferred shape back to the input shapes and the output shapes.
+ shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+ // assign to in_shapes
+ for (size_t i = 0; i < in_shapes->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+ }
+ // assign to out_shapes
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+ }
+ // Check if we have inferred the shapes correctly.
+ return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+ CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+ // Put the input and output data types to the dtype vector.
+ nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+ const auto &input_nids = idx_g.input_nodes();
+ CHECK_EQ(input_nids.size(), in_types->size());
+ for (size_t i = 0; i < in_types->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ types[eid] = in_types->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_types->size());
+ for (size_t i = 0; i < out_types->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ types[eid] = out_types->at(i);
+ }
+
+ // Infer data type of the graph.
+ g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+ g = exec::InferType(std::move(g));
+
+ types = g.GetAttr<nnvm::DTypeVector>("dtype");
+ // assign to in_types
+ for (size_t i = 0; i < in_types->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+ }
+ // assign to out_types
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+ }
+ // Check if we have inferred the dtypes correctly.
+ return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_stypes,
+ std::vector<int>* out_stypes) {
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+ CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+ exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+ // Put the input and output storages to the storage vector.
+ StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
+ const auto &input_nids = idx_g.input_nodes();
+ CHECK_EQ(input_nids.size(), in_stypes->size());
+ for (size_t i = 0; i < in_stypes->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ stypes[eid] = in_stypes->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_stypes->size());
+ for (size_t i = 0; i < out_stypes->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ stypes[eid] = out_stypes->at(i);
+ }
+
+ // Infer storage type of the graph.
+ bool dev_match = g.attrs.count("dev_mask") &&
+ g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+ if (!dev_match) {
+ g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+ }
+ g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+ g = exec::InferStorageType(std::move(g));
+
+ stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ // assign to in_types
+ for (size_t i = 0; i < in_stypes->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
+ }
+
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ // assign to out_types
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
+ }
+ // Check if we have inferred the storages correctly.
+ return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
+ return ExecType::kSubgraphExec;
+}
+
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+ const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
+ const std::vector<std::string> immutable_input_names =
+ subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+ const std::vector<std::string> mutable_input_names =
+ subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+ CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
+ std::vector<uint32_t> ret;
+ size_t i1 = 0, i2 = 0;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
+ ++i1;
+ } else {
+ CHECK(i2 < mutable_input_names.size());
+ CHECK_EQ(input_names[i], mutable_input_names[i2]);
+ ++i2;
+ ret.push_back(i);
+ }
+ }
+ return ret;
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+ static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+ std::set<ResourceRequest::Type> resource_types;
+ DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
+ if (!node->is_variable() && fresource.count(node->op())) {
+ for (ResourceRequest& r : fresource[node->op()](node->attrs)){
+ resource_types.insert(r.type);
+ }
+ }
+ });
+ return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_SUBGRAPH_COMMON_H_
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
new file mode 100644
index 0000000..d5fb7ee
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -0,0 +1,112 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define DEBUG_SUBGRAPH 0
+
+/*!
+ * \brief Operator state that executes a subgraph through a CachedOp.
+ * The CachedOp is built from the subgraph symbol with the flags
+ * static_alloc=true and static_shape=true.
+ */
+class DefaultSubgraphOperator {
+ public:
+  /*!
+   * \brief Keeps a copy of the subgraph symbol and builds its CachedOp.
+   * \param sym the subgraph to execute
+   */
+  explicit DefaultSubgraphOperator(const Symbol& sym)
+      : subgraph_sym_(sym),
+        subgraph_exec_(new CachedOp(sym, {{"static_alloc", "true"},
+                                          {"static_shape", "true"}})) {}
+
+  /*! \brief Runs the subgraph forward; defined out of line. */
+  void Forward(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+  /*! \brief Backward is not supported; reports a fatal "Not implemented" error. */
+  void Backward(const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+ private:
+  /*! \brief copy of the subgraph symbol passed to the constructor */
+  nnvm::Symbol subgraph_sym_;
+  /*! \brief executor built from the subgraph symbol */
+  CachedOpPtr subgraph_exec_;
+};
+
+/*!
+ * \brief Runs the subgraph forward through the cached executor.
+ * \param ctx operator context (not used here)
+ * \param inputs input arrays of the subgraph node
+ * \param req write requests for the outputs (not used here)
+ * \param outputs output arrays of the subgraph node
+ */
+void DefaultSubgraphOperator::Forward(const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  // The executor is called with vectors of non-const NDArray pointers, while
+  // the arguments arrive as const references. Local NDArray copies provide
+  // the objects those pointers refer to.
+  std::vector<NDArray> tmp_inputs = inputs;
+  std::vector<NDArray*> input_ptrs;
+  input_ptrs.reserve(tmp_inputs.size());
+  for (auto& nd : tmp_inputs) {
+    input_ptrs.push_back(&nd);
+  }
+  std::vector<NDArray> tmp_outputs = outputs;
+  std::vector<NDArray*> output_ptrs;
+  // reserve up front, as is done for the inputs, to avoid regrowth
+  output_ptrs.reserve(tmp_outputs.size());
+  for (auto& nd : tmp_outputs) {
+    output_ptrs.push_back(&nd);
+  }
+#if DEBUG_SUBGRAPH
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
+  }
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
+  }
+#endif
+  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
+}
+
+/*!
+ * \brief Creates the operator state for a default subgraph node.
+ * \param attrs node attributes; attrs.subgraphs[0] is the subgraph symbol
+ * \param ctx device context (not used here)
+ * \param in_shapes input shapes (not used here)
+ * \param in_types input types (not used here)
+ * \return state holding a DefaultSubgraphOperator built from the subgraph
+ */
+OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
+                                        Context ctx,
+                                        const std::vector<TShape>& in_shapes,
+                                        const std::vector<int>& in_types) {
+  auto& subgraph_sym = *attrs.subgraphs[0];
+  return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
+}
+
+/*!
+ * \brief Stateful compute entry point of the default subgraph operator.
+ * Fetches the DefaultSubgraphOperator kept in the state and forwards all
+ * arguments to its Forward method.
+ */
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  state_ptr.get_state<DefaultSubgraphOperator>().Forward(ctx, inputs, req, outputs);
+}
+
+// Registers the operator _default_subgraph_op. Input/output counts and names,
+// shape/type/storage-type inference, mutable inputs and the execution type are
+// all delegated to the DefaultSubgraphOp* helpers; the CPU stateful compute
+// function is DefaultSubgraphOpForward.
+// NOTE(review): key_var_num_args names "num_args", but no such parameter is
+// declared in this registration - confirm it is intended.
+NNVM_REGISTER_OP(_default_subgraph_op)
+.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
+.set_num_inputs(DefaultSubgraphOpNumInputs)
+.set_num_outputs(DefaultSubgraphOpNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
DefaultSubgraphOpListInputs)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
DefaultSubgraphOpListOutputs)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
+.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
+.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
DefaultSubgraphOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu
b/src/operator/subgraph/default_subgraph_op.cu
new file mode 100644
index 0000000..008826b
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file default_subgraph_op.cu
+ * \brief GPU Implementation of subgraph operations
+ */
+
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Declaration only; this file does not define the function.
+// NOTE(review): presumably the definition is the one in default_subgraph_op.cc - confirm.
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+// Adds the GPU stateful compute function to _default_subgraph_op; it is the
+// same DefaultSubgraphOpForward function declared above.
+NNVM_REGISTER_OP(_default_subgraph_op)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>",
DefaultSubgraphOpForward);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_property.cc
b/src/operator/subgraph/default_subgraph_property.cc
new file mode 100644
index 0000000..c8d3e9f
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_property.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vector>
+#include <string>
+#include "./common.h"
+#include "./subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Selector that accepts a node only when it is an operator node whose
+ * operator name belongs to a given set. The same test is applied to the seed
+ * node and to the neighbours reached through input and output links.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+  explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
+    : op_names_(op_names) {}
+
+  virtual bool Select(const nnvm::Node &seed_node) {
+    return IsListedOp(seed_node);
+  }
+
+  virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) {
+    return IsListedOp(input_node);
+  }
+
+  virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) {
+    return IsListedOp(output_node);
+  }
+
+ private:
+  // is_variable() is tested first, so op() is only dereferenced for
+  // non-variable nodes.
+  bool IsListedOp(const nnvm::Node &node) const {
+    if (node.is_variable()) {
+      return false;
+    }
+    return op_names_.count(node.op()->name) != 0;
+  }
+
+  // Reference to the set given to the constructor; it is not copied, so the
+  // set has to stay alive for as long as this selector is used.
+  const std::unordered_set<std::string>& op_names_;
+};
+
+/*
+ * Subgraph property that groups nodes whose operators all belong to the set
+ * stored under the attribute "op_names". Each subgraph is wrapped in a node
+ * using the operator _default_subgraph_op.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  static SubgraphPropertyPtr Create() {
+    return std::make_shared<DefaultSubgraphProperty>();
+  }
+
+  // Builds a node named "_default_subgraph_op<subgraph_id>" that carries a
+  // copy of the given symbol as its only subgraph.
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                           const int subgraph_id = 0) const {
+    const std::string op_name = "_default_subgraph_op";
+    nnvm::NodePtr subgraph_node = nnvm::Node::Create();
+    subgraph_node->attrs.op = Op::Get(op_name);
+    subgraph_node->attrs.name = op_name + std::to_string(subgraph_id);
+    subgraph_node->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+    return subgraph_node;
+  }
+
+  // Builds a ContainOpSelector over the "op_names" attribute of this property.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+    const auto& op_names = this->GetAttr<std::unordered_set<std::string>>("op_names");
+    return std::make_shared<ContainOpSelector>(op_names);
+  }
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/partition_graph.cc
b/src/operator/subgraph/partition_graph.cc
new file mode 100644
index 0000000..315f7ee
--- /dev/null
+++ b/src/operator/subgraph/partition_graph.cc
@@ -0,0 +1,774 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./subgraph_property.h"
+
+namespace nnvm {
+// Declared here and used by CutGraphInputs below; this file does not define it.
+// NOTE(review): presumably defined inside the NNVM library - confirm.
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+#define DEBUG_SUBGRAPH 0
+
+namespace sg { // sg stands for subgraph
+
+struct SimpleNode;
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Node of an auxiliary graph that mirrors the structure of the
+ * computational graph. In addition to the original node it records the
+ * consumers of that node, so that a traversal can follow both input and
+ * output edges when searching for subgraphs.
+ */
+struct SimpleNode {
+  /*! \brief creates a node with label -1 and no referenced graph node */
+  static SimpleNodePtr Create() {
+    return std::make_shared<SimpleNode>();
+  }
+  /*! \brief subgraph label; -1 until the node is assigned to a subgraph */
+  int label = -1;
+  /*! \brief the node of the computational graph this simple node refers to */
+  nnvm::Node* node = nullptr;
+  /*!
+   * \brief consumers of the current node
+   * The key is a consumer node; the value lists the positions in
+   * key->inputs whose source is the current node.
+   */
+  std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+};  // struct SimpleNode
+
+#if DEBUG_SUBGRAPH
+/*! \brief Logs the names of all nodes of a subgraph on a single line. */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* snode : simple_nodes) {
+    op_names += snode->node->attrs.name;
+    op_names += ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+/*! \brief Logs the source node name, output index and version of an entry. */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  LOG(INFO) << "NodeEntry: node_name=" << entry.node->attrs.name
+            << ", index=" << std::to_string(entry.index)
+            << ", version=" << std::to_string(entry.version);
+}
+
+/*! \brief Logs every entry of the vector, one log line per entry. */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* entry : entries) {
+    PrintNodeEntry(*entry);
+  }
+}
+#endif
+
+/*!
+ * \brief Builds the auxiliary graph of simple nodes for a computational graph.
+ * Nodes are visited with DFSVisit starting from the graph outputs. For every
+ * visited node a simple node is appended, and the node is registered as a
+ * consumer in the simple node of each of its inputs.
+ * \param g the MXNet computational graph
+ * \param simple_nodes receives the simple nodes in visiting order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& indexed_graph = g.indexed_graph();
+  simple_nodes->reserve(indexed_graph.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr snode = SimpleNode::Create();
+    snode->node = node.get();
+    const auto& node_inputs = snode->node->inputs;
+    for (size_t i = 0; i < node_inputs.size(); ++i) {
+      const auto input_nid = indexed_graph.node_id(node_inputs[i].node.get());
+      // the simple node of an input must already have been created
+      CHECK_LT(input_nid, simple_nodes->size());
+      // operator[] creates the index list on first use
+      (*simple_nodes)[input_nid]->outputs[snode->node].push_back(i);
+    }
+    simple_nodes->emplace_back(std::move(snode));
+  });
+}
+
+/*!
+ * \brief Sets the label of every listed node back to -1 and empties the list.
+ * \param g the whole graph, used to map nodes to node ids
+ * \param simple_nodes simple nodes indexed by node id
+ * \param subgraph_nodes nodes whose labels are reset; cleared on return
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  if (subgraph_nodes->empty()) {
+    return;
+  }
+  const auto& indexed_graph = g.indexed_graph();
+  for (const nnvm::Node* member : *subgraph_nodes) {
+    simple_nodes[indexed_graph.node_id(member)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief Grows a candidate subgraph from a seed node and checks it for cycles.
+ * Starting from the seed, the function follows input and output edges
+ * breadth-first, labels every node accepted by the selector and appends it to
+ * subgraph_nodes. It then checks whether some node outside the candidate both
+ * consumes an output of the candidate and feeds it, either directly or through
+ * other outside nodes. If such a loop exists, the involved subgraph node with
+ * the largest node id is added to excluded_nodes (when one is provided), all
+ * labels set by this call are reset, subgraph_nodes is cleared and false is
+ * returned. Otherwise subgraph_nodes is sorted by node id and true is returned.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes receives the nodes belonging to the subgraph of the seed
+ * \param excluded_nodes nodes that must not be added to the current subgraph;
+ *        may be nullptr, in which case no node is excluded or recorded
+ * \return true if the labelled nodes form a subgraph without a cycle
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  std::queue<SimpleNode*> node_queue;
+  if (!excluded_nodes || !excluded_nodes->count(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes outside the subgraph that are adjacent to it
+  // value: pair of vectors of subgraph nodes. The first vector contains the
+  // subgraph nodes that consume the key, and the second vector contains the
+  // subgraph nodes that the key consumes.
+  // If an outside node consumes the subgraph, another outside node feeds the
+  // subgraph, and the first is an ancestor of the second, there exists a cycle.
+  // When breaking the cycle, we want to start from removing the node with the
+  // largest node id in the subgraph.
+  std::unordered_map<const nnvm::Node*,
+                     std::pair<std::vector<const nnvm::Node*>,
+                               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = (!excluded_nodes || !excluded_nodes->count(e.node.get()))
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = (!excluded_nodes || !excluded_nodes->count(it->first))
+        && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  // prepare to check if there is a cycle
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  std::vector<const nnvm::Node*> non_subgraph_nodes;
+  non_subgraph_nodes.reserve(non_subgraph_node_map.size());
+  for (auto& kv : non_subgraph_node_map) {
+    auto& output_nodes = kv.second.first;
+    std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
+    auto& input_nodes = kv.second.second;
+    std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+    non_subgraph_nodes.push_back(kv.first);
+  }
+  // Hashed copy of the subgraph members, so that the ancestor search below
+  // tests membership in constant time instead of scanning the vector.
+  const std::unordered_set<const nnvm::Node*> subgraph_node_set(subgraph_nodes->begin(),
+                                                                subgraph_nodes->end());
+  // check whether there is a cycle between the subgraph and its input/output nodes
+  auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant) {
+    if (ancestor == descendant) return true;
+    std::stack<const nnvm::Node*> s;
+    // Every node is pushed at most once. Without this, nodes reachable through
+    // several paths would be revisited, and the counter below could exceed the
+    // number of nodes on a perfectly valid acyclic graph.
+    std::unordered_set<const nnvm::Node*> visited;
+    s.push(descendant);
+    visited.insert(descendant);
+    size_t count = 0;
+    while (!s.empty()) {
+      CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably"
+                                                    " a loop in the graph";
+      ++count;
+      const nnvm::Node* top = s.top();
+      s.pop();
+      if (top == ancestor) {
+        return true;
+      }
+      for (const auto& entry : top->inputs) {
+        const nnvm::Node* input = entry.node.get();
+        // when searching for the ancestor, the path cannot cross any subgraph node
+        if (subgraph_node_set.count(input) == 0 && visited.insert(input).second) {
+          s.push(input);
+        }
+      }
+    }
+    return false;
+  };
+  std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+  int excluded_node_id = -1;
+  for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+    auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+    CHECK(it1 != non_subgraph_node_map.end());
+    auto& output_nodes = it1->second.first;  // has been top sorted
+    auto& input_nodes = it1->second.second;  // has been top sorted
+    if (!output_nodes.empty() && !input_nodes.empty()) {
+      // node i both feeds and consumes the subgraph: a direct loop
+      const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
+                                    indexed_graph.node_id(input_nodes.back()));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    } else if (!input_nodes.empty()) {
+      // node i consumes an output of the subgraph. Find out whether there is a
+      // node j which feeds the subgraph and is also a descendant of node i.
+      for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+        auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+        CHECK(it2 != non_subgraph_node_map.end());
+        // i is topologically before j, j might be a direct/indirect output node of i
+        CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
+        if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first)) {
+          // found a loop
+          const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
+                                        indexed_graph.node_id(it2->second.first.back()));
+          excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+        }
+      }
+    }
+  }
+
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // excluded_nodes defaults to nullptr, so it must not be dereferenced blindly
+    if (excluded_nodes != nullptr) {
+      excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    }
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * LabelSubgraph is called repeatedly; each failed attempt adds a node to the
+ * set of excluded nodes, and the next attempt is made without those nodes.
+ * After simple_nodes.size()^2 failed attempts the seed node alone is used as
+ * a one-node subgraph.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes receives the subgraph node candidates of the seed node
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    // The last fragment starts with a space so that the node name is not
+    // glued to the word "as" in the log output.
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                 "is always found between the subgraph and some other nodes. Will treat "
+                 "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    // a failed LabelSubgraph call leaves the node list empty
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Splits a vector of nodes into subgraphs of mutually connected nodes.
+ * Each connected group (following input and output edges, restricted to the
+ * given nodes) becomes one subgraph. Its members are labelled with the current
+ * value of *subgraph_id, which is incremented after every group, and are
+ * sorted by node id.
+ * \param g the whole graph
+ * \param nodes the candidate nodes to group
+ * \param simple_nodes simple nodes indexed by node id
+ * \param subgraphs receives one vector of simple nodes per group
+ * \param subgraph_id next subgraph id to assign; updated in place
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& indexed_graph = g.indexed_graph();
+  // candidates that have not been assigned to a group yet
+  std::unordered_set<nnvm::Node*> node_set(nodes.begin(), nodes.end());
+  auto simple_node_cmp = [&] (const SimpleNode* node1, const SimpleNode* node2) {
+    return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
+  };
+  for (nnvm::Node* seed : nodes) {
+    // erase() reports 0 when the node has already been included in a subgraph
+    if (node_set.erase(seed) == 0U) {
+      continue;
+    }
+    subgraphs->emplace_back();
+    std::vector<SimpleNode*>& members = subgraphs->back();
+    std::queue<nnvm::Node*> frontier;
+    // labels a node, records it as a member and schedules its neighbours
+    auto add_member = [&](nnvm::Node* n) {
+      SimpleNode* snode = simple_nodes[indexed_graph.node_id(n)].get();
+      snode->label = *subgraph_id;
+      members.push_back(snode);
+      frontier.push(n);
+    };
+    add_member(seed);
+    while (!frontier.empty()) {
+      nnvm::Node* cur_node = frontier.front();
+      frontier.pop();
+      for (auto& e : cur_node->inputs) {
+        if (node_set.erase(e.node.get()) != 0U) {
+          add_member(e.node.get());
+        }
+      }
+      const SimpleNode* cur_snode = simple_nodes[indexed_graph.node_id(cur_node)].get();
+      for (const auto& kv : cur_snode->outputs) {
+        if (node_set.erase(kv.first) != 0U) {
+          add_member(kv.first);
+        }
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(members.begin(), members.end(), simple_node_cmp);
+  }
+  CHECK(node_set.empty());
+}
+
+/*!
+ * \brief Finds subgraphs whose nodes all meet the criteria of a selector.
+ * Every node is tried as a seed with a freshly created selector. For an
+ * accepted, still unlabelled seed the candidates are pre-selected, passed
+ * through the selector's Filter function, and finally grouped into connected
+ * subgraphs. All nodes in a subgraph are marked with the same label.
+ * \param g the whole graph
+ * \param subg_prop property that creates the subgraph selectors
+ * \param simple_nodes simple nodes indexed by node id
+ * \param subgraph_nodes receives the nodes of every subgraph found
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    // the selector is asked first; only then is the label inspected
+    const bool is_seed = subgraph_selector->Select(*node) && simple_nodes[i]->label == -1;
+    if (!is_seed) {
+      continue;
+    }
+
+    // pre-select nodes that can be grouped in a subgraph
+    std::vector<nnvm::Node*> preselected_nodes;
+    PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, simple_nodes,
+                           &preselected_nodes);
+
+    // filter out unqualified pre-selected nodes
+    std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
+
+    // make sure filtered_nodes is a subset of preselected_nodes
+    for (const auto n : filtered_nodes) {
+      const auto nit = std::find(preselected_nodes.begin(), preselected_nodes.end(), n);
+      CHECK(nit != preselected_nodes.end())
+        << "Node " << n->attrs.name << " is not found in the pre-selected subgraph nodes."
+           " Please make sure that no new nodes were added in your subgraph"
+           " selector's Filter function";
+    }
+
+    // make sure nodes are sorted
+    std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+    // reset node labels that are not in filtered nodes
+    for (const auto n : preselected_nodes) {
+      const bool was_dropped =
+          std::find(filtered_nodes.begin(), filtered_nodes.end(), n) == filtered_nodes.end();
+      if (was_dropped) {
+        simple_nodes[indexed_graph.node_id(n)]->label = -1;
+      }
+    }
+
+    // find out subgraphs from the filtered nodes; appending an empty result
+    // leaves subgraph_nodes unchanged
+    std::vector<std::vector<SimpleNode*>> subgraphs;
+    PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, &subgraph_id);
+    subgraph_nodes->insert(subgraph_nodes->end(), subgraphs.begin(), subgraphs.end());
+  }
+}
+
+/*!
+ * \brief Sorts entries by the position recorded for them in a lookup table.
+ * Note that entry ids cannot be used to sort entries. Every entry must have
+ * a position in the table; a missing one fails a CHECK.
+ * \param entry_top_order_map mapping from entry pointer to its topological
+ *        position in the graph
+ * \param entries node entries to be sorted in place
+ */
+void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                 std::vector<nnvm::NodeEntry*>* entries) {
+  std::sort(entries->begin(), entries->end(),
+            [&entry_top_order_map](const nnvm::NodeEntry* e1, const nnvm::NodeEntry* e2) {
+    const auto it1 = entry_top_order_map.find(e1);
+    const auto it2 = entry_top_order_map.find(e2);
+    CHECK(it1 != entry_top_order_map.end());
+    CHECK(it2 != entry_top_order_map.end());
+    return it1->second < it2->second;
+  });
+}
+
+/*!
+ * \brief Given a subgraph, find the input entries of the subgraph.
+ * An input entry is an entry in the inputs of a subgraph node whose source
+ * node either is unknown to the indexed graph or carries a different label.
+ * The result is sorted with SortEntries.
+ * \param g the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param input_entries receives the input entries of the subgraph
+ */
+void FindInputEntries(const Graph& g,
+                      const std::vector<SimpleNodePtr>& simple_nodes,
+                      const std::vector<SimpleNode*>& subgraph_nodes,
+                      const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                        entry_top_order_map,
+                      std::vector<nnvm::NodeEntry*>* input_entries) {
+  const auto& indexed_graph = g.indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    // all nodes of one subgraph are expected to share the same label
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    for (nnvm::NodeEntry& e : subgraph_nodes[i]->node->inputs) {
+      const nnvm::Node* source = e.node.get();
+      // A source missing from the indexed graph is a subgraph node created
+      // earlier, i.e. two subgraphs are adjacent; it counts as external too.
+      const bool produced_inside = indexed_graph.exist(source)
+          && simple_nodes[indexed_graph.node_id(source)]->label == label;
+      if (!produced_inside) {
+        input_entries->push_back(&e);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, input_entries);
+}
+
+/*!
+ * \brief Given a subgraph, find the output entries of the subgraph.
+ * An output entry is an input entry of a consumer that lies outside the
+ * subgraph but reads from one of its nodes, or an output entry of the whole
+ * graph whose source node belongs to the subgraph. The result is sorted with
+ * SortEntries.
+ * \param g pointer to the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param output_entries receives the output entries of the subgraph
+ */
+void FindOutputEntries(Graph* g,
+                       const std::vector<SimpleNodePtr>& simple_nodes,
+                       const std::vector<SimpleNode*>& subgraph_nodes,
+                       const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                         entry_top_order_map,
+                       std::vector<nnvm::NodeEntry*>* output_entries) {
+  if (subgraph_nodes.empty()) return;
+  const auto& indexed_graph = g->indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    // all nodes of one subgraph are expected to share the same label
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    for (const auto& kv : subgraph_nodes[i]->outputs) {
+      // By default the consumer is a subgraph node created earlier (two
+      // subgraphs are adjacent); its inputs are reached through the key.
+      std::vector<nnvm::NodeEntry>* consumer_inputs = &kv.first->inputs;
+      if (indexed_graph.exist(kv.first)) {
+        // the consumer is a normal graph node (not a subgraph node)
+        const auto nid = indexed_graph.node_id(kv.first);
+        if (simple_nodes[nid]->label == label) {
+          continue;  // the consumer belongs to the current subgraph
+        }
+        consumer_inputs = &simple_nodes[nid]->node->inputs;
+      }
+      for (auto idx : kv.second) {
+        output_entries->push_back(&(*consumer_inputs)[idx]);
+      }
+    }
+  }
+  // Check if current subgraph contains a node which is the last node
+  // of the whole graph. If so, save its corresponding entry as well.
+  for (auto& entry : g->outputs) {
+    // The entry might already have been redirected to the output of a
+    // subgraph node. Such a source is not in the indexed graph and needs
+    // no check against the current subgraph.
+    if (!indexed_graph.exist(entry.node.get())) {
+      continue;
+    }
+    const auto nid = indexed_graph.node_id(entry.node.get());
+    if (simple_nodes[nid]->label == label) {
+      output_entries->push_back(&entry);
+    }
+  }
+  SortEntries(entry_top_order_map, output_entries);
+}
+
+/*!
+ * \brief Cuts the given input entries and reconnects them to new variable nodes.
+ * Each processed entry is first copied to the same position of orig_entries
+ * and then overwritten in place with an entry pointing at a newly created
+ * variable node. The variable is named after the output name of the original
+ * entry followed by a counter, so entries sharing an output name still get
+ * distinct variables.
+ * \param input_entries entries to cut; modified in place
+ * \param orig_entries resized to input_entries.size(); receives the original
+ *        entries (positions of skipped entries stay default-constructed)
+ * \param skip_var when true, entries whose source is a variable are left as is
+ */
+void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
+                    std::vector<nnvm::NodeEntry> *orig_entries,
+                    const bool skip_var = false) {
+  orig_entries->resize(input_entries.size());
+  // number of times each output name has been seen, minus one; used as the
+  // suffix that makes the variable names unique
+  std::unordered_map<std::string, int> name_count_map;
+  for (size_t i = 0; i < input_entries.size(); ++i) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    if (e->node->is_variable() && skip_var) {
+      continue;
+    }
+
+    (*orig_entries)[i] = *e;
+    // a one-output symbol is only built to obtain the output name of the entry
+    nnvm::Symbol sym;
+    sym.outputs.push_back(*e);
+    const auto output_names = sym.ListOutputNames();
+    CHECK_EQ(output_names.size(), 1U);
+    const std::string& var_name = output_names[0];
+    // emplace() keeps an existing counter; it is bumped only on a repeat
+    const auto slot = name_count_map.emplace(var_name, 0);
+    if (!slot.second) {
+      ++(slot.first->second);
+    }
+    nnvm::NodePtr n =
+        nnvm::CreateVariableNode(var_name + std::to_string(slot.first->second));
+    *e = nnvm::NodeEntry{n, 0, 0};
+  }
+}
+
+/*!
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
+ * and keep the subgraph in the subgraph node. The input entries and output entries
+ * of the subgraph node are kept in the same order as the subgraph's.
+ * \param g the graph being partitioned; its entries are rewired in place
+ * \param simple_nodes simple-node view of the graph, indexed by node id
+ * \param subgraph_nodes the nodes that form the subgraph to be replaced
+ * \param subgraph_id id forwarded to SubgraphProperty::CreateSubgraphNode
+ * \param entry_top_order_map topological order of node entries; extended with
+ * the input entries of the newly created subgraph node
+ */
+void CreateSubgraphNode(Graph* g,
+                        const std::vector<SimpleNodePtr>& simple_nodes,
+                        const std::vector<SimpleNode*>& subgraph_nodes,
+                        const size_t subgraph_id,
+                        std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+#if DEBUG_SUBGRAPH
+  LOG(INFO) << "Searching for input entries...";
+#endif
+  // Inputs are located and cut first; the output search below runs on the
+  // graph after the cut.
+  std::vector<nnvm::NodeEntry*> input_entries;
+  FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
+  std::vector<nnvm::NodeEntry> orig_input_entries;
+  CutGraphInputs(input_entries, &orig_input_entries, false);
+#if DEBUG_SUBGRAPH
+  PrintNodeEntries(input_entries);
+  LOG(INFO) << "Searching for output entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> output_entries;
+  FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
+  const size_t num_outputs = output_entries.size();
+
+  // Create a subgraph for the subgraph node: its outputs are copies of the
+  // output entries, in the same order.
+  nnvm::Symbol sym;
+  sym.outputs.reserve(num_outputs);
+  for (const nnvm::NodeEntry* out_entry : output_entries) {
+    sym.outputs.push_back(*out_entry);
+  }
+  const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  nnvm::NodePtr subgraph_node = subg_prop->CreateSubgraphNode(sym, subgraph_id);
+
+  // Connect the external nodes to the subgraph node: the i-th output entry
+  // now reads output i of the subgraph node.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    *output_entries[i] = nnvm::NodeEntry{subgraph_node, static_cast<uint32_t>(i), 0};
+  }
+  subgraph_node->inputs = orig_input_entries;
+  const auto& idx_graph = g->indexed_graph();
+  for (size_t i = 0; i < subgraph_node->inputs.size(); ++i) {
+    auto& input = subgraph_node->inputs[i];
+    // Update entry_top_order_map with the newly created input entry: it
+    // inherits the order of the entry it replaces.
+    const auto order_it = entry_top_order_map->find(input_entries[i]);
+    CHECK(order_it != entry_top_order_map->end());
+    const size_t order = order_it->second;
+    entry_top_order_map->emplace(&input, order);
+    // Update the outputs map of the input entry's source simple node.
+    nnvm::Node* source = input.node.get();
+    if (!idx_graph.exist(source)) continue;
+    const auto nid = idx_graph.node_id(source);
+    SimpleNode* source_sn = simple_nodes[nid].get();
+    // The source no longer feeds the individual subgraph members ...
+    for (SimpleNode* member : subgraph_nodes) {
+      source_sn->outputs.erase(member->node);
+    }
+    // ... it feeds input i of the subgraph node instead.
+    source_sn->outputs[subgraph_node.get()].push_back(i);
+  }
+#if DEBUG_SUBGRAPH
+  PrintNodeEntries(output_entries);
+#endif
+}
+
+} // namespace sg
+
+/*!
+ * \brief Sort entries of all the nodes' inputs vectors in the topological order.
+ * This is going to be used to sort input/output entries of subgraphs to keep
+ * the topological order unchanged.
+ * \param g the graph whose output and input entries are ordered
+ * \param entry_top_order_map maps each visited entry to its order; must not be null
+ */
+void TopSortEntries(const Graph& g,
+                    std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+  CHECK(entry_top_order_map != nullptr);
+  std::unordered_set<const nnvm::Node*> visited;
+  // One frame of the explicit depth-first stack:
+  // (graph node, index of the node's next input to visit,
+  //  node entry that is the output of the graph node)
+  using Frame = std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>;
+  std::stack<Frame> pending;
+  auto in_degree = [] (const nnvm::Node* node)->size_t {
+    if (node == nullptr) {
+      return 0;
+    }
+    // Nodes with control dependencies are not supported here.
+    CHECK_EQ(node->control_deps.size(), 0U);
+    return node->inputs.size();
+  };
+  // Assigns the next free order number to an entry; an entry that already
+  // has an order keeps it.
+  auto mark_order = [entry_top_order_map] (const nnvm::NodeEntry* entry) {
+    const size_t order = entry_top_order_map->size();
+    entry_top_order_map->emplace(entry, order);
+  };
+  for (auto& graph_output : g.outputs) {
+    nnvm::Node* output_node = graph_output.node.get();
+    if (visited.count(output_node) != 0U) {
+      // The entry's source node has been visited before.
+      // Marking the order for it.
+      mark_order(&graph_output);
+    } else {
+      pending.emplace(output_node, 0U, &graph_output);
+      visited.insert(output_node);
+    }
+    while (!pending.empty()) {
+      auto& frame = pending.top();
+      nnvm::Node* current = std::get<0>(frame);
+      size_t& next_input = std::get<1>(frame);
+      if (next_input == in_degree(current)) {
+        // The node's inputs have been exhausted.
+        mark_order(std::get<2>(frame));
+        pending.pop();
+        continue;
+      }
+      // The node still has input entries not visited.
+      CHECK_LT(next_input, current->inputs.size());
+      auto& input = current->inputs[next_input++];
+      nnvm::Node* input_node = input.node.get();
+      if (visited.count(input_node) != 0U) {
+        // The entry's source node has been visited before.
+        // Marking the order for it.
+        mark_order(&input);
+      } else {
+        // The entry's source node has not been visited.
+        // Push the entry to the stack for marking order later.
+        pending.emplace(input_node, 0U, &input);
+        visited.insert(input_node);
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Graph pass that replaces groups of nodes with subgraph nodes, as
+ * selected by the SubgraphProperty stored in the graph attribute
+ * "subgraph_property".
+ * \param g the graph to partition
+ * \return the partitioned graph, or the graph as passed in when the
+ * attribute is missing
+ */
+Graph PartitionGraph(Graph&& g) {
+  if (!g.HasAttr("subgraph_property")) {
+    // Without a subgraph property there is nothing to partition by.
+    LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
+                 "The original graph is returned.";
+    return g;
+  }
+  using namespace sg;
+  const SubgraphPropertyPtr& property = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  // Topologically sort the NodeEntry objects of all the nodes' inputs.
+  std::unordered_map<const nnvm::NodeEntry*, size_t> entry_order;
+  TopSortEntries(g, &entry_order);
+
+  // Create undirected graph for ease of finding subgraphs.
+  std::vector<SimpleNodePtr> simple_nodes;
+  CreateSimpleGraph(g, &simple_nodes);
+  std::vector<std::vector<SimpleNode*>> subgraphs;
+  FindSubgraphs(&g, *property, simple_nodes, &subgraphs);
+  // Replace every subgraph found with one subgraph node; the position of the
+  // subgraph in the list is used as its id.
+  for (size_t subgraph_id = 0; subgraph_id < subgraphs.size(); ++subgraph_id) {
+    const std::vector<SimpleNode*>& members = subgraphs[subgraph_id];
+#if DEBUG_SUBGRAPH
+    // A node must not be listed twice in the same subgraph.
+    std::set<SimpleNode*> simple_node_set(members.begin(), members.end());
+    CHECK_EQ(simple_node_set.size(), members.size());
+    PrintSubgraph(members);
+#endif
+    CreateSubgraphNode(&g, simple_nodes, members, subgraph_id, &entry_order);
+  }
+  return g;
+}
+
+// Registers PartitionGraph as an NNVM pass under the name "PartitionGraph".
+// set_change_graph(true) declares that the pass modifies the graph.
+NNVM_REGISTER_PASS(PartitionGraph)
+.describe("Partition a graph according to the user defined rules "
+ "in a derived class of SubgraphProperty")
+.set_body(PartitionGraph)
+.set_change_graph(true);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/subgraph_property.h
b/src/operator/subgraph/subgraph_property.h
new file mode 100644
index 0000000..cfbc1f8
--- /dev/null
+++ b/src/operator/subgraph/subgraph_property.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+
+#include <nnvm/node.h>
+#include <dmlc/base.h>
+#include <dmlc/thread_local.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Criteria used by the graph partitioning algorithm to select nodes
+ * for subgraphs.
+ *
+ * The algorithm first sorts all the nodes in topological order, and then
+ * loops through the sorted nodes and tries to find a subgraph starting
+ * from each node (we call it a seed node) that satisfies the following two
+ * conditions:
+ * 1. The node has not been selected before.
+ * 2. The function Select is called on the node and returns true.
+ *
+ * Expanding from this seed node, we do BFS to traverse the graph.
+ * During the traversal, we call SelectInput and SelectOutput to determine
+ * if a neighboring node of the current node should be selected as a
+ * candidate for the subgraph. The search continues when a new node is
+ * selected as a candidate, and terminates when no more qualified nodes are
+ * found. When the search ends, all of the candidate nodes will be passed to
+ * the function Filter to finalize the subgraph. The filtering gives
+ * developers the last opportunity to drop some of the candidate nodes.
+ * By default, Filter returns all nodes as the subgraph nodes.
+ * If the pre-selected subgraph becomes disconnected because some nodes are
+ * filtered out in the Filter function, the algorithm will automatically
+ * convert the rest of the nodes to multiple valid subgraphs based upon
+ * their connectivity.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() = default;
+
+  /*!
+   * \brief Determines whether to search for other nodes to form a subgraph
+   * starting from the seed_node.
+   * \param seed_node the candidate seed node
+   */
+  virtual bool Select(const nnvm::Node &seed_node) = 0;
+
+  /*!
+   * \brief Determines whether to select input_node when the traversal
+   * reaches cur_node.
+   * \param cur_node the node for determining whether its input_node should be selected
+   * \param input_node the input node of the cur_node
+   */
+  virtual bool SelectInput(const nnvm::Node &cur_node,
+                           const nnvm::Node &input_node) = 0;
+
+  /*!
+   * \brief Determines whether to select output_node when the traversal
+   * reaches cur_node.
+   * \param cur_node the node for determining whether its output_node should be selected
+   * \param output_node the output node of the cur_node
+   */
+  virtual bool SelectOutput(const nnvm::Node &cur_node,
+                            const nnvm::Node &output_node) = 0;
+
+  /*!
+   * \brief Post-processes the pre-selected subgraph nodes.
+   * \param candidates the nodes selected during the traversal
+   * \return the nodes that users want to keep in subgraph(s); the default
+   * implementation keeps all candidates
+   */
+  virtual std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  /*!
+   * \brief Virtual destructor. Derived properties are handled through
+   * pointers to this base class (see SubgraphPropertyPtr), so destruction
+   * through the base must reach the derived destructor.
+   */
+  virtual ~SubgraphProperty() {}
+  /*!
+   * \brief Creates the selector that holds the criteria of selecting the
+   * subgraph nodes.
+   */
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  /*!
+   * \brief Creates an nnvm node for a given subgraph. Here users can
+   * customize how to execute the operators in the subgraph.
+   * \param s the symbol whose outputs are the outputs of the subgraph
+   * \param subgraph_id id of the subgraph the node is created for
+   */
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+  /*!
+   * \brief Sets an attribute with the given name in the attribute map,
+   * replacing any value stored under that name before.
+   * \return a reference to this object, so that calls can be chained
+   */
+  template<typename T>
+  SubgraphProperty& SetAttr(const std::string& name, const T& value) {
+    // nnvm::any is used here to agree with the declared type of attrs_.
+    attrs_[name] = std::make_shared<nnvm::any>(value);
+    return *this;
+  }
+  /*!
+   * \brief Gets the attribute stored under the given name as type T.
+   * Fails a CHECK when no attribute with that name exists.
+   */
+  template<typename T>
+  const T& GetAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
+    return nnvm::get<T>(*it->second);
+  }
+ protected:
+  /*! \brief attribute name -> type-erased attribute value */
+  std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+/*!
+ * \brief Singleton registry that maps a subgraph property name to the
+ * function creating that property.
+ */
+class SubgraphPropertyRegistry {
+ public:
+  /*! \brief signature of a function that creates a subgraph property */
+  using SubgraphPropertyCreateFn = SubgraphPropertyPtr (*)(void);
+
+  /*! \brief Returns the single registry instance. */
+  static SubgraphPropertyRegistry* Get() {
+    static SubgraphPropertyRegistry inst;
+    return &inst;
+  }
+
+  /*!
+   * \brief Creates the subgraph property registered under the given name.
+   * Fails a CHECK when the name has not been registered.
+   */
+  SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) {
+    const auto entry = prop_fn_map_.find(name);
+    CHECK(entry != prop_fn_map_.end()) << "SubgraphProperty " << name
+                                       << " is not found in SubgraphPropertyRegistry";
+    SubgraphPropertyCreateFn create = entry->second;
+    return create();
+  }
+
+  /*!
+   * \brief Registers fn under name unless the name is already taken.
+   * \return the function stored under name after the call: fn for a new
+   * name, the previously registered function otherwise
+   */
+  SubgraphPropertyCreateFn __REGISTER_OR_GET__(const std::string& name,
+                                               SubgraphPropertyCreateFn fn) {
+    const auto entry = prop_fn_map_.find(name);
+    if (entry != prop_fn_map_.end()) {
+      return entry->second;
+    }
+    return __REGISTER__(name, fn);
+  }
+
+ private:
+  /*!
+   * \brief Registers fn under name. Fails a CHECK when the name is already
+   * registered.
+   */
+  SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) {
+    CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name
+                                           << " has been registered";
+    prop_fn_map_[name] = fn;
+    return fn;
+  }
+
+  // Instances are only created through Get(); copying and moving are disabled.
+  SubgraphPropertyRegistry() = default;
+  SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete;
+  SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete;
+  SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete;
+  /*! \brief property name -> creator function */
+  std::unordered_map<std::string, SubgraphPropertyCreateFn> prop_fn_map_;
+};
+
+// Names of the operators that should be grouped into subgraphs, kept in a
+// dmlc::ThreadLocalStore. In practice, every backend accelerator should have
+// a predefined name set; this set is only used for testing purposes.
+// key: property name, value: op name set
+using SubgraphPropertyOpNameSet =
+  dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>;
+
+// Registers SubgraphPropertyType::Create in the SubgraphPropertyRegistry under
+// the stringified Name. The registration runs as the initializer of a static
+// variable whose name is pasted together from SubgraphPropertyType and Name.
+// __REGISTER_OR_GET__ is used, so a Name that is already registered keeps its
+// existing creator function.
+#define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \
+ static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ##
Name ## __ = \
+ SubgraphPropertyRegistry::Get()->__REGISTER_OR_GET__(#Name,
&SubgraphPropertyType::Create)
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
diff --git a/tests/cpp/engine/threaded_engine_test.cc
b/tests/cpp/engine/threaded_engine_test.cc
index 92d0958..6d669c1 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -275,6 +275,64 @@ TEST(Engine, basics) {
LOG(INFO) << "All pass";
}
+// Checks the version counter of an engine variable on every engine type:
+// the version is expected to stay at 0 when the variable is only used as a
+// read dependency, and to equal the number of pushed operators (10) when it
+// is used as a write dependency.
+TEST(Engine, VarVersion) {
+  // NOTE(review): the engines created here are never deleted in this test --
+  // confirm whether that is intentional.
+  const std::vector<mxnet::Engine*> engines = {
+    mxnet::engine::CreateNaiveEngine(),
+    mxnet::engine::CreateThreadedEnginePooled(),
+    mxnet::engine::CreateThreadedEnginePerDevice()};
+  const std::string type_names[] = {"NaiveEngine", "ThreadedEnginePooled",
+                                    "ThreadedEnginePerDevice"};
+  for (size_t engine_idx = 0; engine_idx < engines.size(); ++engine_idx) {
+    mxnet::Engine* engine = engines[engine_idx];
+    const std::string& type_name = type_names[engine_idx];
+    std::vector<mxnet::Engine::OprHandle> oprs;
+
+    // Phase 1: the variable is only read by the operators.
+    LOG(INFO) << "Testing var as a read dependency in " << type_name;
+    auto var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    for (int i = 0; i < 10; ++i) {
+      auto opr = engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {var}, {});
+      oprs.push_back(opr);
+      engine->Push(opr, mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 0U);
+    for (const auto& opr : oprs) {
+      engine->DeleteOperator(opr);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    // Phase 2: a fresh variable is written by the operators.
+    LOG(INFO) << "Testing var as a write dependency in " << type_name;
+    var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    oprs.clear();
+    for (int i = 0; i < 10; ++i) {
+      auto opr = engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {}, {var});
+      oprs.push_back(opr);
+      engine->Push(opr, mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 10U);
+    for (const auto& opr : oprs) {
+      engine->DeleteOperator(opr);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    var = nullptr;
+    oprs.clear();
+    LOG(INFO) << "All pass";
+  }
+}
+
#ifdef _OPENMP
struct TestSaveAndRestoreOMPState {
diff --git a/tests/python/gpu/test_operator_gpu.py
b/tests/python/gpu/test_operator_gpu.py
index 5612b0a..0ff33e1 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -41,6 +41,7 @@ from test_exc_handling import *
from test_sparse_ndarray import *
from test_sparse_operator import *
from test_ndarray import *
+from test_subgraph_op import *
set_default_context(mx.gpu(0))
del test_support_vector_machine_l1_svm # noqa
diff --git a/tests/python/unittest/test_subgraph_op.py
b/tests/python/unittest/test_subgraph_op.py
new file mode 100644
index 0000000..40d609a
--- /dev/null
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import ctypes
+import mxnet as mx
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array,
c_str
+from mxnet.symbol import Symbol
+import numpy as np
+from mxnet.test_utils import assert_almost_equal
+
+
+def test_subgraph_exe():
+ def _check_subgraph_exe1(sym, op_names):
+ """Use the partitioned sym to simple_bind an executor and compare the
outputs
+ with those of the original executor"""
+ out = SymbolHandle()
+ check_call(_LIB.MXPartitionGraphByOpNames(sym.handle,
c_str('default'), mx_uint(len(op_names)),
+ c_str_array(op_names),
ctypes.byref(out)))
+
+ partitioned_sym = Symbol(out)
+ assert partitioned_sym.list_inputs() == sym.list_inputs()
+ assert partitioned_sym.list_arguments() == sym.list_arguments()
+ assert partitioned_sym.list_auxiliary_states() ==
sym.list_auxiliary_states()
+ exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+ partitioned_exe =
partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+ input_names = sym.list_inputs()
+ for name in input_names:
+ if name in exe.arg_dict:
+ exe.arg_dict[name][:] =
mx.nd.random.uniform(shape=exe.arg_dict[name].shape)
+ partitioned_exe.arg_dict[name][:] = exe.arg_dict[name]
+ else:
+ assert name in exe.aux_dict
+ exe.aux_dict[name][:] =
mx.nd.random.uniform(shape=exe.aux_dict[name].shape)
+ partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
+ exe.forward()
+ partitioned_exe.forward()
+ assert len(exe.outputs) == len(partitioned_exe.outputs)
+ for i in range(len(exe.outputs)):
+ assert_almost_equal((exe.outputs[i] -
partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+ np.zeros(shape=(1,)))
+
+ def _check_subgraph_exe2(sym, op_names):
+ """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph
partitioning in simple_bind
+ and compare results of the partitioned sym and the original sym."""
+ def get_executor(sym, subgraph_backend=None, op_names=None,
original_exec=None):
+ if subgraph_backend is not None:
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
mx_uint(len(op_names)),
+
c_str_array(op_names)))
+ exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+ input_names = sym.list_inputs()
+ for name in input_names:
+ if name in exe.arg_dict:
+ exe.arg_dict[name][:] =
mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
+ if original_exec is None else
original_exec.arg_dict[name]
+ else:
+ assert name in exe.aux_dict
+ exe.aux_dict[name][:] =
mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
+ if original_exec is None else
original_exec.aux_dict[name]
+ exe.forward()
+ if subgraph_backend is not None:
+
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+ del os.environ['MXNET_SUBGRAPH_BACKEND']
+ return exe
+
+ original_exec = get_executor(sym)
+ partitioned_exec = get_executor(sym, 'default', op_names,
original_exec)
+ outputs1 = original_exec.outputs
+ outputs2 = partitioned_exec.outputs
+ assert len(outputs1) == len(outputs2)
+ for i in range(len(outputs1)):
+ assert_almost_equal((outputs1[i] -
outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+ def _check_subgraph_exe3(sym, op_names):
+ """Use the partitioned sym to bind an executor and compare the outputs
+ with those of the original executor"""
+ out = SymbolHandle()
+ check_call(_LIB.MXPartitionGraphByOpNames(sym.handle,
c_str('default'), mx_uint(len(op_names)),
+ c_str_array(op_names),
ctypes.byref(out)))
+
+ partitioned_sym = Symbol(out)
+ input_names = sym.list_inputs()
+ arg_names = sym.list_arguments()
+ aux_names = sym.list_auxiliary_states()
+ assert partitioned_sym.list_inputs() == input_names
+ assert partitioned_sym.list_arguments() == arg_names
+ assert partitioned_sym.list_auxiliary_states() == aux_names
+ arg_shapes, _, aux_shapes = sym.infer_shape()
+ arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+ aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+ exe = sym.bind(ctx=mx.current_context(), args=arg_array,
aux_states=aux_array, grad_req='null')
+ partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(),
args=arg_array,
+ aux_states=aux_array,
grad_req='null')
+ exe.forward()
+ partitioned_exe.forward()
+ assert len(exe.outputs) == len(partitioned_exe.outputs)
+ for i in range(len(exe.outputs)):
+ assert_almost_equal((exe.outputs[i] -
partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+ np.zeros(shape=(1,)))
+
+ def _check_subgraph_exe4(sym, op_names):
+ """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph
partitioning in bind
+ and compare results of the partitioned sym and the original sym."""
+ def get_executor(sym, subgraph_backend=None, op_names=None,
original_exec=None):
+ if subgraph_backend is not None:
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
mx_uint(len(op_names)),
+
c_str_array(op_names)))
+ arg_shapes, _, aux_shapes = sym.infer_shape()
+ if subgraph_backend is None:
+ arg_array = [mx.nd.random.uniform(shape=shape) for shape in
arg_shapes]
+ aux_array = [mx.nd.random.uniform(shape=shape) for shape in
aux_shapes]
+ else:
+ arg_array = None
+ aux_array = None
+ exe = sym.bind(ctx=mx.current_context(),
+ args=arg_array if subgraph_backend is None else
original_exec.arg_arrays,
+ aux_states=aux_array if subgraph_backend is None
else original_exec.aux_arrays,
+ grad_req='null')
+ exe.forward()
+ if subgraph_backend is not None:
+
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+ del os.environ['MXNET_SUBGRAPH_BACKEND']
+ return exe
+
+ original_exec = get_executor(sym)
+ partitioned_exec = get_executor(sym, 'default', op_names,
original_exec)
+ outputs1 = original_exec.outputs
+ outputs2 = partitioned_exec.outputs
+ assert len(outputs1) == len(outputs2)
+ for i in range(len(outputs1)):
+ assert_almost_equal((outputs1[i] -
outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+ def check_subgraph_exe(sym, op_names):
+ _check_subgraph_exe1(sym, op_names)
+ _check_subgraph_exe2(sym, op_names)
+ _check_subgraph_exe3(sym, op_names)
+ _check_subgraph_exe4(sym, op_names)
+
+ def test_network_structure_1():
+ data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
+ data2 = mx.sym.var('data2')
+ conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True,
kernel=(2, 2), num_filter=1)
+ conv2 = mx.sym.Convolution(data=data2, no_bias=True, kernel=(1, 1),
num_filter=1)
+ out = mx.sym.Group([conv1, conv2])
+ check_subgraph_exe(out, ['Convolution'])
+
+ def test_network_structure_2():
+ # this tests whether the partitioning algorithm can deal with cycles
+ data = mx.sym.var('data', shape=(2, 3, 10, 10))
+ ret = mx.sym.exp(data)
+ ret1 = mx.sym.cos(ret)
+ ret2 = mx.sym.sin(ret)
+ ret = ret1 + ret2
+ check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add',
'_plus'])
+ check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add',
'_plus'])
+
+ def test_network_structure_3():
+ # this tests whether the partitioned sym can distinguish in_args and
aux_states
+ data = mx.sym.var('data', shape=(2, 3, 10, 10))
+ ret = mx.sym.exp(data)
+ ret1 = mx.sym.cos(ret)
+ ret2 = mx.sym.sin(ret)
+ ret = ret1 + ret2
+ ret = mx.sym.BatchNorm(ret)
+ ret = mx.sym.BatchNorm(ret)
+ check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add',
'_plus'])
+ check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add',
'_plus'])
+ check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add',
'_plus', 'BatchNorm'])
+ check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add',
'_plus', 'BatchNorm'])
+ check_subgraph_exe(ret, ['exp', 'BatchNorm'])
+ check_subgraph_exe(ret, ['BatchNorm'])
+
+ def test_network_structure_4():
+ # the last op has multiple duplicate outputs
+ data = mx.sym.var('data', shape=(2, 3, 10, 10))
+ ret = mx.sym.exp(data)
+ ret = mx.sym.Group([ret, ret, ret])
+ check_subgraph_exe(ret, ['exp'])
+
+ def test_network_structure_5():
+ # the subgraph has two duplicate input entries
+ data = mx.sym.var('data', shape=(2, 3, 10, 10))
+ ret = data + data
+ check_subgraph_exe(ret, ['_plus', '_Plus', 'elemwise_add'])
+
+ def test_network_structure_6():
+ def get_graph():
+ data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10),
dtype=np.float32)
+ data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
+ data3 = mx.sym.sin(data2)
+ conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2),
num_filter=1)
+ rets = [(conv, []),
+ (conv, [mx.sym.sin.__name__]),
+ (conv, [mx.sym.Convolution.__name__]),
+ (conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__])]
+ return rets
+
+ for sym, op_names in get_graph():
+ check_subgraph_exe(sym, op_names)
+
+ def test_network_structure_7():
+ # in this graph, the subgraph node and the other two external nodes
form a cycle
+ data = mx.sym.Variable('data', shape=(1,))
+ ret1 = mx.sym.sin(data)
+ ret2 = mx.sym.cos(ret1)
+ for _ in range(5):
+ ret2 = mx.sym.cos(ret2)
+ ret = ret1 + ret2
+ check_subgraph_exe(ret, ['sin', 'elemwise_add', '_plus', '_Plus'])
+
+ test_network_structure_1()
+ test_network_structure_2()
+ test_network_structure_3()
+ test_network_structure_4()
+ test_network_structure_5()
+ test_network_structure_6()
+ test_network_structure_7()
+
+
+# When this file is executed directly, hand this module over to nose.
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()