This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fe04e9  add SameType as default type inference function in imperative mode (#9942)
0fe04e9 is described below

commit 0fe04e9d778ac9c55e933ff3f21c4ddf28a4a101
Author: Ziyue Huang <zyhuan...@gmail.com>
AuthorDate: Wed Mar 7 13:43:32 2018 +0800

    add SameType as default type inference function in imperative mode (#9942)
    
    * add SameType as default in imperative mode
    
    * move SameType and DefaultStorageType to src/common
---
 src/common/exec_utils.h               | 61 +++++++++++++++++++++++++++++++++++
 src/executor/exec_pass.h              | 11 -------
 src/executor/infer_graph_attr_pass.cc | 56 ++------------------------------
 src/imperative/imperative_utils.h     | 14 +++++---
 4 files changed, 73 insertions(+), 69 deletions(-)

diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 5fd1a9b..29537d3 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -169,6 +169,67 @@ inline void CastNonDefaultStorage(const std::vector<NDArray>& src,
     }
   }
 }
+
+/*! \brief The default type inference function, which assigns all undefined
+ *         types to the same type of one of the inputs or outputs.
+ */
+inline bool SameType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *iattr,
+                     std::vector<int> *oattr) {
+  int def_v = -1;
+  for (int v : *oattr) {
+    if (v != -1) {
+      def_v = v; break;
+    }
+  }
+  if (def_v == -1) {
+    for (int v : *iattr) {
+      if (v != -1) {
+        def_v = v; break;
+      }
+    }
+  }
+  if (def_v == -1) return false;
+  for (int& v : *oattr) {
+    v = def_v;
+  }
+  for (int& v : *iattr) {
+    v = def_v;
+  }
+  return true;
+}
+
+
+/*! \brief The default storage type inference function, which assigns all undefined
+ *         storage types to kDefaultStorage. If all of input and output storage types
+ *         are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise,
+ *         DispatchMode::kFComputeFallback is assigned to dispatch_mode.
+ */
+inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               DispatchMode* dispatch_mode,
+                               std::vector<int> *iattr,
+                               std::vector<int> *oattr) {
+  bool fallback = false;
+  for (int& v : *oattr) {
+    if (v == -1) v = kDefaultStorage;
+    if (v != kDefaultStorage) fallback = true;
+  }
+  for (int& v : *iattr) {
+    if (v == -1) v = kDefaultStorage;
+    if (v != kDefaultStorage) fallback = true;
+  }
+  if (*dispatch_mode == DispatchMode::kUndefined) {
+    if (fallback) {
+      *dispatch_mode = DispatchMode::kFComputeFallback;
+    } else {
+      *dispatch_mode = DispatchMode::kFCompute;
+    }
+  }
+  return true;
+}
+
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_EXEC_UTILS_H_
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index bf4b147..99b1b16 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -178,17 +178,6 @@ Graph InferStorageType(Graph&& graph,
                        StorageTypeVector&& storage_type_inputs = StorageTypeVector(),
                        const std::string& storage_type_attr_key = "");
 
-/*! \brief The default storage type inference function, which assigns all undefined
- *         storage types to kDefaultStorage. If all of input and output storage types
- *         are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise,
- *         DispatchMode::kFComputeFallback is assigned to dispatch_mode.
- */
-bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
-                        const int dev_mask,
-                        DispatchMode* dispatch_mode,
-                        std::vector<int> *iattr,
-                        std::vector<int> *oattr);
-
 }  // namespace exec
 }  // namespace mxnet
 
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 01fab22..191fbe9 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -26,6 +26,7 @@
 #include <mxnet/graph_attr_types.h>
 #include "./exec_pass.h"
 #include "../operator/operator_common.h"
+#include "../common/exec_utils.h"
 
 namespace mxnet {
 namespace exec {
@@ -321,57 +322,6 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   return ret;
 }
 
-// inference fucntion for same type
-inline bool SameType(const nnvm::NodeAttrs& attrs,
-                     std::vector<int> *iattr,
-                     std::vector<int> *oattr) {
-  int def_v = -1;
-  for (int v : *oattr) {
-    if (v != -1) {
-      def_v = v; break;
-    }
-  }
-  if (def_v == -1) {
-    for (int v : *iattr) {
-      if (v != -1) {
-        def_v = v; break;
-      }
-    }
-  }
-  if (def_v == -1) return false;
-  for (int& v : *oattr) {
-    v = def_v;
-  }
-  for (int& v : *iattr) {
-    v = def_v;
-  }
-  return true;
-}
-
-inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
-                               const int dev_mask,
-                               DispatchMode* dispatch_mode,
-                               std::vector<int> *iattr,
-                               std::vector<int> *oattr) {
-  bool fallback = false;
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
-    if (v != kDefaultStorage) fallback = true;
-  }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
-    if (v != kDefaultStorage) fallback = true;
-  }
-  if (*dispatch_mode == DispatchMode::kUndefined) {
-    if (fallback) {
-      *dispatch_mode = DispatchMode::kFComputeFallback;
-    } else {
-      *dispatch_mode = DispatchMode::kFCompute;
-    }
-  }
-  return true;
-}
-
 nnvm::Graph InferShape(nnvm::Graph&& graph,
                        nnvm::ShapeVector&& shape_inputs,
                        const std::string& shape_attr_key) {
@@ -405,7 +355,7 @@ nnvm::Graph InferType(nnvm::Graph&& graph,
       "FInferType", "dtype_inputs", "dtype_attr_key",
       "dtype", "dtype_num_unknown_nodes",
       [](const int t) { return t == -1; },
-      SameType, true, nullptr);
+      common::SameType, true, nullptr);
 }
 
 nnvm::Graph InferStorageType(nnvm::Graph&& graph,
@@ -438,7 +388,7 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph,
       "FInferStorageType", "storage_type_inputs", "storage_type_attr_key",
       "storage_type", "storage_type_num_unknown_nodes",
       [](const int t) { return t == -1; },
-      DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable);
+      common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable);
 
   // log the storage types and dispatch modes of the graph
   bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 966a753..044ab82 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -120,9 +120,13 @@ inline void SetShapeType(const Context& ctx,
   for (auto& i : outputs) {
     out_types.push_back(i->dtype());
   }
-  CHECK(infertype.count(attrs.op))
-    << "Operator " << attrs.op->name << " is missing FInferType attribute";
-  CHECK(infertype[attrs.op](attrs, &in_types, &out_types));
+  bool infer_type_success = false;
+  if (infertype.count(attrs.op)) {
+    infer_type_success = infertype[attrs.op](attrs, &in_types, &out_types);
+  } else {
+    infer_type_success = common::SameType(attrs, &in_types, &out_types);
+  }
+  CHECK(infer_type_success) << "Operator " << attrs.op->name << " is missing FInferType attribute";
   CHECK_EQ(out_types.size(), outputs.size());
 
   // infer storage type
@@ -138,13 +142,13 @@ inline void SetShapeType(const Context& ctx,
   for (auto& i : outputs) {
     out_storage_types.push_back(i->storage_type());
   }
-  bool infer_stype_success;
+  bool infer_stype_success = false;
   if (inferstorage.count(attrs.op)) {
     infer_stype_success = inferstorage[attrs.op](attrs, ctx.dev_mask(), dispatch_mode,
                                                  &in_storage_types, &out_storage_types);
   } else {
     // if infer storage attr is not present, apply the default infer storage function
-    infer_stype_success = exec::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode,
+    infer_stype_success = common::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode,
                                                    &in_storage_types, &out_storage_types);
   }
   CHECK(infer_stype_success) << "Operator not implemented: "

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to