This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 0fe04e9 add SameType as default type inference function in imperative mode (#9942) 0fe04e9 is described below commit 0fe04e9d778ac9c55e933ff3f21c4ddf28a4a101 Author: Ziyue Huang <zyhuan...@gmail.com> AuthorDate: Wed Mar 7 13:43:32 2018 +0800 add SameType as default type inference function in imperative mode (#9942) * add SameType as default in imperative mode * move SameType and DefaultStorageType to src/common --- src/common/exec_utils.h | 61 +++++++++++++++++++++++++++++++++++ src/executor/exec_pass.h | 11 ------- src/executor/infer_graph_attr_pass.cc | 56 ++------------------------------ src/imperative/imperative_utils.h | 14 +++++--- 4 files changed, 73 insertions(+), 69 deletions(-) diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 5fd1a9b..29537d3 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -169,6 +169,67 @@ inline void CastNonDefaultStorage(const std::vector<NDArray>& src, } } } + +/*! \brief The default type inference function, which assigns all undefined + * types to the same type of one of the inputs or outputs. + */ +inline bool SameType(const nnvm::NodeAttrs& attrs, + std::vector<int> *iattr, + std::vector<int> *oattr) { + int def_v = -1; + for (int v : *oattr) { + if (v != -1) { + def_v = v; break; + } + } + if (def_v == -1) { + for (int v : *iattr) { + if (v != -1) { + def_v = v; break; + } + } + } + if (def_v == -1) return false; + for (int& v : *oattr) { + v = def_v; + } + for (int& v : *iattr) { + v = def_v; + } + return true; +} + + +/*! \brief The default storage type inference function, which assigns all undefined + * storage types to kDefaultStorage. If all of input and output storage types + * are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise, + * DispatchMode::kFComputeFallback is assigned to dispatch_mode. + */ +inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *iattr, + std::vector<int> *oattr) { + bool fallback = false; + for (int& v : *oattr) { + if (v == -1) v = kDefaultStorage; + if (v != kDefaultStorage) fallback = true; + } + for (int& v : *iattr) { + if (v == -1) v = kDefaultStorage; + if (v != kDefaultStorage) fallback = true; + } + if (*dispatch_mode == DispatchMode::kUndefined) { + if (fallback) { + *dispatch_mode = DispatchMode::kFComputeFallback; + } else { + *dispatch_mode = DispatchMode::kFCompute; + } + } + return true; +} + + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_EXEC_UTILS_H_ diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index bf4b147..99b1b16 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -178,17 +178,6 @@ Graph InferStorageType(Graph&& graph, StorageTypeVector&& storage_type_inputs = StorageTypeVector(), const std::string& storage_type_attr_key = ""); -/*! \brief The default storage type inference function, which assigns all undefined - * storage types to kDefaultStorage. If all of input and output storage types - * are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise, - * DispatchMode::kFComputeFallback is assigned to dispatch_mode. - */ -bool DefaultStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector<int> *iattr, - std::vector<int> *oattr); - } // namespace exec } // namespace mxnet diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 01fab22..191fbe9 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -26,6 +26,7 @@ #include <mxnet/graph_attr_types.h> #include "./exec_pass.h" #include "../operator/operator_common.h" +#include "../common/exec_utils.h" namespace mxnet { namespace exec { @@ -321,57 +322,6 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, return ret; } -// inference fucntion for same type -inline bool SameType(const nnvm::NodeAttrs& attrs, - std::vector<int> *iattr, - std::vector<int> *oattr) { - int def_v = -1; - for (int v : *oattr) { - if (v != -1) { - def_v = v; break; - } - } - if (def_v == -1) { - for (int v : *iattr) { - if (v != -1) { - def_v = v; break; - } - } - } - if (def_v == -1) return false; - for (int& v : *oattr) { - v = def_v; - } - for (int& v : *iattr) { - v = def_v; - } - return true; -} - -inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector<int> *iattr, - std::vector<int> *oattr) { - bool fallback = false; - for (int& v : *oattr) { - if (v == -1) v = kDefaultStorage; - if (v != kDefaultStorage) fallback = true; - } - for (int& v : *iattr) { - if (v == -1) v = kDefaultStorage; - if (v != kDefaultStorage) fallback = true; - } - if (*dispatch_mode == DispatchMode::kUndefined) { - if (fallback) { - *dispatch_mode = DispatchMode::kFComputeFallback; - } else { - *dispatch_mode = DispatchMode::kFCompute; - } - } - return true; -} - nnvm::Graph InferShape(nnvm::Graph&& graph, nnvm::ShapeVector&& shape_inputs, const std::string& shape_attr_key) { @@ -405,7 +355,7 @@ nnvm::Graph InferType(nnvm::Graph&& graph, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, - SameType, true, nullptr); + common::SameType, true, nullptr); } nnvm::Graph InferStorageType(nnvm::Graph&& graph, @@ -438,7 +388,7 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph, "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", "storage_type", "storage_type_num_unknown_nodes", [](const int t) { return t == -1; }, - DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable); + common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable); // log the storage types and dispatch modes of the graph bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 966a753..044ab82 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -120,9 +120,13 @@ inline void SetShapeType(const Context& ctx, for (auto& i : outputs) { out_types.push_back(i->dtype()); } - CHECK(infertype.count(attrs.op)) - << "Operator " << attrs.op->name << " is missing FInferType attribute"; - CHECK(infertype[attrs.op](attrs, &in_types, &out_types)); + bool infer_type_success = false; + if (infertype.count(attrs.op)) { + infer_type_success = infertype[attrs.op](attrs, &in_types, &out_types); + } else { + infer_type_success = common::SameType(attrs, &in_types, &out_types); + } + CHECK(infer_type_success) << "Operator " << attrs.op->name << " is missing FInferType attribute"; CHECK_EQ(out_types.size(), outputs.size()); // infer storage type @@ -138,13 +142,13 @@ inline void SetShapeType(const Context& ctx, for (auto& i : outputs) { out_storage_types.push_back(i->storage_type()); } - bool infer_stype_success; + bool infer_stype_success = false; if (inferstorage.count(attrs.op)) { infer_stype_success = inferstorage[attrs.op](attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); } else { // if infer storage attr is not present, apply the default infer storage function - infer_stype_success = exec::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode, + infer_stype_success = common::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); } CHECK(infer_stype_success) << "Operator not implemented: " -- To stop receiving notification emails like this one, please contact j...@apache.org.