comaniac commented on a change in pull request #6655:
URL: https://github.com/apache/incubator-tvm/pull/6655#discussion_r528990741
##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
const std::string& target = "") {
std::string ref_target = "";
+ Array<Expr> compiler_begins;
Array<Expr> compiler_ends;
for (auto arg : args) {
- std::string arg_target = "default";
+ std::string arg_target = default_target;
const CallNode* call = arg.as<CallNode>();
if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
ICHECK_EQ(call->args.size(), 1U);
+ // Do not alter existing annotation if not default
+ if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
+ compiler_begins.push_back(arg);
+ } else {
+ // Remove default
+ compiler_ends.push_back(call->args[0]);
+ }
Review comment:
Thanks for the explanation. I think I am now clearer about the first
part, but we might need to improve the code comments a lot to make it clear to
everyone reading this pass.
For the second part, your understanding is correct. As you mentioned, the
definition of "simple" is vague, but it's also more general. Since every target
needs to specify its own transform pipeline (e.g.,
https://github.com/apache/incubator-tvm/blob/main/python/tvm/relay/op/contrib/arm_compute_lib.py#L44),
every BYOC target can define its own "simple". For example, the
definition of "simple" in ACL could be a subgraph with non-call ops; the
definition of "simple" in TensorRT could be a subgraph without MAC ops. It
means something like
```python
seq = tvm.transform.Sequential(
    [
        transform.InferType(),
        transform.MergeComposite(arm_compute_lib_pattern_table()),
        transform.AnnotateTarget("arm_compute_lib"),
        transform.PruneSimpleRegions(tvm._ffi.get_global_func("relay.ext.arm_compute_lib.prune")),
        transform.PartitionGraph(),
    ]
)
```
where `tvm._ffi.get_global_func("relay.ext.arm_compute_lib.prune")` is a
packed function in C++ that accepts an `AnnotatedRegion` and outputs a boolean,
indicating whether the region should be pruned or not.
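
To make the per-target predicate idea concrete, here is a minimal, TVM-free sketch. Everything in it is hypothetical and only for illustration: `Region` stands in for `AnnotatedRegion`, `MAC_OPS` is an assumed set of op names, and the ACL rule reads "a subgraph with non-call ops" as "a region containing no call ops", which is my interpretation rather than a settled definition.

```python
from dataclasses import dataclass
from typing import Callable, List

# Hypothetical stand-in for AnnotatedRegion: only the target name and the
# names of the call ops inside the region (empty if there are no calls).
@dataclass
class Region:
    target: str
    ops: List[str]

# Assumed set of multiply-accumulate ops, for illustration only.
MAC_OPS = {"nn.conv2d", "nn.dense", "nn.batch_matmul"}

def acl_is_simple(region: Region) -> bool:
    # Assumed ACL rule: a region with no call ops is "simple".
    return len(region.ops) == 0

def trt_is_simple(region: Region) -> bool:
    # Assumed TensorRT rule: a region without any MAC op is "simple".
    return not any(op in MAC_OPS for op in region.ops)

def prune_simple_regions(
    regions: List[Region], is_simple: Callable[[Region], bool]
) -> List[Region]:
    # Keep only the regions that the target's predicate does not prune.
    return [r for r in regions if not is_simple(r)]

trt_regions = [
    Region("tensorrt", ["add", "nn.relu"]),    # no MAC op, gets pruned
    Region("tensorrt", ["nn.conv2d", "add"]),  # has a MAC op, is kept
]
kept = prune_simple_regions(trt_regions, trt_is_simple)
```

The pruning pass itself stays target-agnostic; only the predicate passed in changes per BYOC target, which is the point of letting each target define its own "simple".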