This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 09202f7f Improve static cached_op optimization (#15187)
09202f7f is described below
commit 09202f7f261954383aa387144524d38f83f18d06
Author: Xinyu Chen <[email protected]>
AuthorDate: Thu Jun 13 11:08:15 2019 +0800
Improve static cached_op optimization (#15187)
* Fix cached op
Change-Id: If90c6f0997548ffd5daa67cc18bab7405f24213b
* Fix UT
* trigger
---
src/imperative/cached_op.cc | 2 +-
src/imperative/imperative_utils.h | 5 ++++-
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 07c7871..b49cad4 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -290,7 +290,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
- if (contain_dynamic_shape && erase_result) {
+ if (!config_.static_shape && erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 5cb805c..4e63e4d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -595,7 +595,10 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
*contain_unknown = false;
}
nnvm::Graph& g = *p_g;
- if (g.attrs.count("shape")) {
+ if (use_inputs) {
+ if (g.attrs.count("shape_inputs") &&
g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes)
+ return true;
+ } else if (g.attrs.count("shape")) {
const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
if (prev_shapes.size() == shapes.size()) {
bool match = true;