This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 369b66d Improve cached_op performance for static mode (#14785)
369b66d is described below
commit 369b66d0f10ba479ce96f78f7c838bd7bc41d951
Author: Zhennan Qin <[email protected]>
AuthorDate: Sat Apr 27 01:52:38 2019 +0800
Improve cached_op performance for static mode (#14785)
* Fix cached_op
* try to fix ci
* Fix CI
* Fix ci
---
src/executor/attach_op_execs_pass.cc | 8 ++++++--
src/executor/exec_pass.h | 9 ++++++++-
src/imperative/cached_op.cc | 10 ++++++----
src/imperative/imperative_utils.h | 26 ++++++++++++--------------
4 files changed, 32 insertions(+), 21 deletions(-)
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index b04d132..8f47bc2 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
using nnvm::DTypeVector;
using mxnet::ShapeVector;
using nnvm::FMutateInputs;
@@ -302,6 +302,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
+ if (p_state) {
+ CHECK_GT(p_state->size(), i);
+ p_state->at(i) = state;
+ }
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
@@ -359,7 +363,7 @@ Graph AttachOpExecs(Graph g) {
const auto& idx = g.indexed_graph();
OpExecVector ret(idx.num_nodes());
for (size_t i = 0; i < idx.num_nodes(); ++i) {
- CreateOpExecs(g, &ret, i);
+ CreateOpExecs(g, &ret, nullptr, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index dd41323..7e5130f 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -99,6 +99,12 @@ class OpExecutor {
using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;
/*!
+ * \brief per node vector of operator states.
+ * \note stored under attribute "op_states"
+ */
+using OpStateVector = std::vector<OpStatePtr>;
+
+/*!
* \brief per node context vector
* \note stored under "context"
*/
@@ -115,9 +121,10 @@ using DevMaskVector = std::vector<int>;
*
* \param g input graph
* \param p_ret OpExecVector for input and output
+ * \param p_state optional OpStateVector to receive the created operator states; may be nullptr
* \param i the id of the node
*/
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
/*!
* \brief Attach OpExecutor to the graph attributes.
*
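For context, the new optional output parameter declared above lets a caller collect the OpStatePtr created for each stateful node. A minimal sketch of the intended call pattern, where the variable names (execs, op_states) are illustrative rather than taken from the patch:

    const auto& idx = g.indexed_graph();
    exec::OpExecVector execs(idx.num_nodes());
    // Size the state vector up front; CreateOpExecs CHECKs p_state->size() > i.
    exec::OpStateVector op_states(idx.num_nodes());
    for (size_t i = 0; i < idx.num_nodes(); ++i) {
      // Callers that do not need the states pass nullptr, as AttachOpExecs does.
      exec::CreateOpExecs(g, &execs, &op_states, i);
    }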
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c9215c5..7a5ed21 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
- if (erase_result) {
+ if (contain_dynamic_shape && erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
@@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
}
} else {
for (size_t i = start_nid; i < end_nid; ++i) {
- exec::CreateOpExecs(g, &state.execs, i);
+ exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
}
exec::AttachOpResources(g, state.execs, start_nid, end_nid);
@@ -705,8 +705,10 @@ void CachedOp::StaticRunOps(
arg_shapes.emplace_back(ndinput->shape());
arg_dtypes.emplace_back(ndinput->dtype());
}
- state.op_states[i] = createop[node.source->op()](
- node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+ if (!state.op_states[i]) {
+ state.op_states[i] =
+ createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+ }
Imperative::Get()->InvokeOp(
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, state.op_states[i]);
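The guard added above is the create-once pattern behind the speedup: an unset OpStatePtr converts to false, so the stateful operator is constructed on the first StaticRunOps pass and reused on every later run instead of being recreated each time. A standalone sketch of the same pattern, with simplified stand-in types and hypothetical names (RunNode, create_state) rather than MXNet's:

    #include <cstddef>
    #include <functional>
    #include <memory>
    #include <vector>

    using OpState = std::shared_ptr<void>;  // stand-in for MXNet's OpStatePtr

    void RunNode(std::size_t i, std::vector<OpState>* states,
                 const std::function<OpState()>& create_state) {
      // Construct the state only on the first run; later runs reuse it.
      if (!(*states)[i]) {
        (*states)[i] = create_state();
      }
      // ... invoke the operator with (*states)[i] ...
    }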
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 9d4e4bd..5c97068 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -595,23 +595,21 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
*contain_unknown = false;
}
nnvm::Graph& g = *p_g;
- if (use_inputs) {
- if (g.attrs.count("shape_inputs") &&
- g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes) return true;
- } else if (g.attrs.count("shape")) {
+ if (g.attrs.count("shape")) {
const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
- CHECK_EQ(prev_shapes.size(), shapes.size());
- bool match = true;
- for (size_t i = 0; i < shapes.size(); ++i) {
- if (i == entry_range.first) {
- i = entry_range.second;
- if (i >= shapes.size()) break;
+ if (prev_shapes.size() == shapes.size()) {
+ bool match = true;
+ for (size_t i = 0; i < shapes.size(); ++i) {
+ if (i == entry_range.first) {
+ i = entry_range.second;
+ if (i >= shapes.size()) break;
+ }
+ if (shapes[i] == prev_shapes[i]) continue;
+ match = false;
+ break;
}
- if (shapes[i] == prev_shapes[i]) continue;
- match = false;
- break;
+ if (match) return true;
}
- if (match) return true;
}
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");