masahi commented on a change in pull request #5493:
URL: https://github.com/apache/incubator-tvm/pull/5493#discussion_r418512702



##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -54,39 +54,27 @@ namespace partitioning {
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
-/*!
- * \brief The checker that verifies if a Relay program is annotated correctly
- * for partitioning.
+/*! \brief This struct maintains the required metadata for a region to generate a corresponding
+ * global function and function call. Global function will be passed to the target specific codegen
+ * and function call will be used in the transform Relay graph to invoke the function in runtime.
  */
-class AnnotationChecker : public ExprVisitor {
- public:
-  bool Check() {
-    if (!found_start_ && !found_end_) {
-      LOG(WARNING) << "No compiler annotation found";
-    } else if (!found_start_) {
-      LOG(ERROR) << "compiler_begin annotation is missing";
-      return false;
-    } else if (!found_end_) {
-      LOG(ERROR) << "compiler_end annotation is missing";
-      return false;
-    }
-    return true;
-  }
+struct RegionFuncMetadata {
+  /*! \brief The call node of the generated global function for this region. */
+  Call func_call;
 
-  void VisitExpr_(const CallNode* call) final {
-    auto op_node = call->op.as<OpNode>();
-    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      return;
-    } else if (call->op == compiler_begin_op) {
-      found_start_ = true;
-    } else if (call->op == compiler_end_op) {
-      found_end_ = true;
-    }
-  }
+  /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
+   * as a function node argument; input expression is used as a function call parameter.
+   */
+  std::vector<std::pair<Var, Expr>> args;
 
- private:
-  bool found_start_{false};
-  bool found_end_{false};
+  /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */
+  std::unordered_map<Expr, std::pair<int, TupleGetItem>, ObjectHash, ObjectEqual> out_expr_indices;
+
+  /*! \brief Map from each input expression to the corresponding input variable of this region.
+   * This cache is used to make sure a region function will not have duplicated inputs even

Review comment:
       even if

##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -54,39 +54,27 @@ namespace partitioning {
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
-/*!
- * \brief The checker that verifies if a Relay program is annotated correctly
- * for partitioning.
+/*! \brief This struct maintains the required metadata for a region to generate a corresponding
+ * global function and function call. Global function will be passed to the target specific codegen
+ * and function call will be used in the transform Relay graph to invoke the function in runtime.
  */
-class AnnotationChecker : public ExprVisitor {
- public:
-  bool Check() {
-    if (!found_start_ && !found_end_) {
-      LOG(WARNING) << "No compiler annotation found";
-    } else if (!found_start_) {
-      LOG(ERROR) << "compiler_begin annotation is missing";
-      return false;
-    } else if (!found_end_) {
-      LOG(ERROR) << "compiler_end annotation is missing";
-      return false;
-    }
-    return true;
-  }
+struct RegionFuncMetadata {
+  /*! \brief The call node of the generated global function for this region. */
+  Call func_call;
 
-  void VisitExpr_(const CallNode* call) final {
-    auto op_node = call->op.as<OpNode>();
-    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      return;
-    } else if (call->op == compiler_begin_op) {
-      found_start_ = true;
-    } else if (call->op == compiler_end_op) {
-      found_end_ = true;
-    }
-  }
+  /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
+   * as a function node argument; input expression is used as a function call parameter.
+   */
+  std::vector<std::pair<Var, Expr>> args;
 
- private:
-  bool found_start_{false};
-  bool found_end_{false};
+  /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */
+  std::unordered_map<Expr, std::pair<int, TupleGetItem>, ObjectHash, ObjectEqual> out_expr_indices;
+
+  /*! \brief Map from each input expression to the corresponding input variable of this region.
+   * This cache is used to make sure a region function will not have duplicated inputs even
+   * it refers the same expr multuple times.

Review comment:
       multiple




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to