comaniac commented on a change in pull request #6655:
URL: https://github.com/apache/incubator-tvm/pull/6655#discussion_r529032585



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
   std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
                                                    const std::string& target = "") {
     std::string ref_target = "";
+    Array<Expr> compiler_begins;
     Array<Expr> compiler_ends;
     for (auto arg : args) {
-      std::string arg_target = "default";
+      std::string arg_target = default_target;
       const CallNode* call = arg.as<CallNode>();
 
       if (call && call->op == CompilerBeginOp()) {
         // Argument is already compiler begin node meaning that this is not the first time
         // running this pass, so we simply remove it and will add a new one later.
         ICHECK_EQ(call->args.size(), 1U);
+        // Do not alter existing annotation if not default
+        if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
+          compiler_begins.push_back(arg);
+        } else {
+          // Remove default
+          compiler_ends.push_back(call->args[0]);
+        }

Review comment:
   IMHO, it should be fine, because we have abstracted the region into a separate data structure called `AnnotatedRegion`. If we pass an `AnnotatedRegion` to `PruneSimpleRegions`, we can work purely at the region level. We could add an API such as `ChangeRegionTarget` to `AnnotatedRegion` that changes the compiler attribute of every annotation node in that region. `PruneSimpleRegions` could then simply do (in pseudocode):
   
   ```
   for region in regions:
       if region.target != "default" and is_simple_region(region):
           region.change_region_target("default")
   ```
   
   and one possible implementation of `is_simple_region` could be:
   
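   ```
   from tvm import relay
   from tvm.relay.expr_functor import ExprVisitor

   BEGIN = relay.op.get("annotation.compiler_begin")
   END = relay.op.get("annotation.compiler_end")

   def is_simple_region(region):
       class Checker(ExprVisitor):
           def __init__(self):
               super().__init__()  # initializes the visitor's memo map
               self.is_simple = True

           def visit_call(self, call):
               if call.op.same_as(BEGIN):
                   return  # region boundary: do not walk past it
               if not call.op.same_as(END):
                   self.is_simple = False  # a real op inside the region
               super().visit_call(call)

       checker = Checker()
       for out in region.outs:
           checker.visit(out)
       return checker.is_simple
   ```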
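To illustrate the proposed flow without a TVM build, here is a minimal, self-contained sketch in plain Python. It assumes a simplified data model: `Node`, `Region`, and `prune_simple_regions` are hypothetical stand-ins (not TVM APIs), `change_region_target` mirrors the proposed `ChangeRegionTarget`, and a region is assumed to be "simple" when it contains no call nodes.

```python
from dataclasses import dataclass, field
from typing import List

DEFAULT_TARGET = "default"


@dataclass
class Node:
    """Hypothetical stand-in for an annotated Relay expression in a region."""
    kind: str                      # e.g. "call", "var", "tuple"
    target: str = DEFAULT_TARGET   # compiler attribute of its annotation


@dataclass
class Region:
    """Hypothetical stand-in for AnnotatedRegion."""
    target: str
    nodes: List[Node] = field(default_factory=list)

    def change_region_target(self, new_target: str) -> None:
        # Mirrors the proposed ChangeRegionTarget: retarget the region and
        # the compiler attribute of every node inside it.
        self.target = new_target
        for node in self.nodes:
            node.target = new_target


def is_simple_region(region: Region) -> bool:
    # Assumption: a region is "simple" if it contains no call nodes,
    # i.e. offloading it would not move any real computation.
    return not any(node.kind == "call" for node in region.nodes)


def prune_simple_regions(regions: List[Region]) -> None:
    for region in regions:
        if region.target != DEFAULT_TARGET and is_simple_region(region):
            region.change_region_target(DEFAULT_TARGET)


# Usage: "dnnl" is only an example external-compiler target name.
regions = [
    Region("dnnl", [Node("tuple", "dnnl"), Node("var", "dnnl")]),
    Region("dnnl", [Node("call", "dnnl")]),
]
prune_simple_regions(regions)
print([r.target for r in regions])  # ['default', 'dnnl']
```

Working at the region level keeps the pruning logic independent of how the annotations are laid out in the graph; only `change_region_target` needs to know about the individual annotation nodes.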
