comaniac commented on a change in pull request #6655:
URL: https://github.com/apache/incubator-tvm/pull/6655#discussion_r528990741



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
   std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
                                                    const std::string& target = "") {
     std::string ref_target = "";
+    Array<Expr> compiler_begins;
     Array<Expr> compiler_ends;
     for (auto arg : args) {
-      std::string arg_target = "default";
+      std::string arg_target = default_target;
       const CallNode* call = arg.as<CallNode>();
 
       if (call && call->op == CompilerBeginOp()) {
         // Argument is already compiler begin node meaning that this is not the first time
         // running this pass, so we simply remove it and will add a new one later.
         ICHECK_EQ(call->args.size(), 1U);
+        // Do not alter existing annotation if not default
+        if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
+          compiler_begins.push_back(arg);
+        } else {
+          // Remove default
+          compiler_ends.push_back(call->args[0]);
+        }

Review comment:
   Thanks for the explanation. The first part is clearer to me now, but we should improve the code comments substantially so that the logic is clear to everyone reading this pass.
   
   For the second part, your understanding is correct. As you mentioned, the definition of "simple" is vague, but it is also more general. Since every target needs to specify its own transform pipeline (e.g., https://github.com/apache/incubator-tvm/blob/main/python/tvm/relay/op/contrib/arm_compute_lib.py#L44), every BYOC target can define its own "simple". For example, "simple" in ACL could mean a subgraph with non-call ops; "simple" in TensorRT could mean a subgraph without MAC ops. It means something like
   
   ```python
       seq = tvm.transform.Sequential(
           [
               transform.InferType(),
               transform.MergeComposite(arm_compute_lib_pattern_table()),
               transform.AnnotateTarget("arm_compute_lib"),
               transform.PruneSimpleRegions(tvm._ffi.get_global_func("relay.ext.arm_compute_lib.prune")),
               transform.PartitionGraph(),
           ]
       )
   ```
   
   where `tvm._ffi.get_global_func("relay.ext.arm_compute_lib.prune")` is a packed function in C++ that accepts an `AnnotatedRegion` and outputs a boolean, indicating whether the region should be pruned or not.
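
To make the per-target idea concrete, here is a rough, TVM-free sketch of how such a predicate could be consulted. Everything in it is hypothetical: `Region`, `MAC_OPS`, `prune_without_mac`, and `prune_simple_regions` are illustrative stand-ins, not existing TVM APIs, and the set of MAC ops is only an assumed example.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Region:
    """Hypothetical stand-in for an annotated region: a target plus op names."""
    target: str
    ops: List[str]


# Assumed example set of multiply-accumulate ops; a real target would pick its own.
MAC_OPS = {"nn.conv2d", "nn.dense", "nn.batch_matmul"}


def prune_without_mac(region: Region) -> bool:
    """One possible "simple" rule: prune a region that contains no MAC op."""
    return not any(op in MAC_OPS for op in region.ops)


def prune_simple_regions(
    regions: List[Region],
    predicates: Dict[str, Callable[[Region], bool]],
) -> List[Region]:
    """Keep only the regions whose target-specific predicate does not prune them."""
    kept = []
    for region in regions:
        predicate = predicates.get(region.target)
        if predicate is not None and predicate(region):
            continue  # pruned: the region would fall back to the default target
        kept.append(region)
    return kept


regions = [
    Region("tensorrt", ["nn.conv2d", "nn.relu"]),
    Region("tensorrt", ["add"]),
]
kept = prune_simple_regions(regions, {"tensorrt": prune_without_mac})
```

With this shape, the pass itself stays target-agnostic, and each BYOC backend only registers the boolean rule that encodes its own notion of "simple".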
   
   





