zhiics commented on a change in pull request #5616:
URL: https://github.com/apache/incubator-tvm/pull/5616#discussion_r427416869
##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -404,21 +404,96 @@ IRModule RemoveDefaultAnnotations(IRModule module) {
return module;
}
+/*! \brief There can be regions with multiple outputs where each output
+ * could be a tuple output. Such tuple outputs need to be flattened;
+ * otherwise the function would create tuples of tuples.
+ */
+
+// New annotations would be required to be added for each flattened output
+const PackedFunc* make_end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+
+IRModule FlattenTupleOutputs(IRModule module) {
+ class TupleOutFlattener : public ExprRewriter {
+ public:
+ TupleOutFlattener() = default;
+
+ Expr InsertAnnotation(const Expr& expr, const std::string& target, const
PackedFunc* ann_op) {
+ Expr new_op = (*ann_op)(expr, target);
+ new_op->checked_type_ = expr->checked_type_;
+ return new_op;
+ }
+
+ Expr Rewrite_(const CallNode* call, const Expr& post) final {
+ if (call->op == compiler_end_op) {
+ std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+ // An annotation op should have exactly one argument
+ CHECK_EQ(call->args.size(), 1U);
+ auto annotated_op = Downcast<Call>(post)->args[0];
+ if (annotated_op->IsInstance<TupleNode>()) {
+ auto tn = annotated_op.as<TupleNode>();
+ Array<Expr> new_fields;
+
+ // Here each field of the tuple is annotated with its own compiler_end
+ for (auto& tn_arg : tn->fields) {
+ auto nf = InsertAnnotation(tn_arg, target, make_end_op);
+ new_fields.push_back(nf);
+ }
+
+ // Return a tuple of compiler_ends in the place of the tuple that was
+ // annotated with a compiler_end.
+ auto out = Tuple(new_fields);
+ return std::move(out);
+ }
+ }
+ return post;
+ }
+ };
+
+ auto glob_funcs = module->functions;
+ // module is mutable, hence, we make a copy of it.
+ module.CopyOnWrite();
+ for (const auto& pair : glob_funcs) {
+ if (auto* fn = pair.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(fn);
+ TupleOutFlattener to_flattener;
+ auto removed = PostOrderRewrite(func->body, &to_flattener);
+ func = Function(func->params, removed, func->ret_type,
func->type_params, func->attrs);
+ module->Update(pair.first, func);
+ }
+ }
+ return module;
+}
+
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m,
-
PassContext pc) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples =
[=](IRModule m,
+
PassContext pc) {
+ // There could be compiler_end annotations on tuples.
+ // If the corresponding region has multiple compiler_ends,
+ // this would lead to the creation of tuples of tuples.
+ // Thus, we flatten the tuples by transferring the compiler_end to
+ // the tuple inputs.
+ return partitioning::FlattenTupleOutputs(m);
+ };
+
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> remove_defaults =
[=](IRModule m,
+
PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with
"default" attribute
// by treating them as un-annotated, but we don't have it yet. This
workaround pass removes
// all "default" annotations and should be deleted in the future.
- auto new_m = partitioning::RemoveDefaultAnnotations(m);
- return partitioning::Partitioner(new_m).Partition();
+ return partitioning::RemoveDefaultAnnotations(m);
};
- auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
- return Sequential({partitioned, InferType()});
+
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
+ [=](IRModule m, PassContext pc) { return
partitioning::Partitioner(m).Partition(); };
+
+ auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0,
"FlattenNestedTuples", {});
+ auto remove_default_pass = CreateModulePass(remove_defaults, 0,
"RemoveDefaultAnnotations", {});
+ auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {});
+ return Sequential({flatten_tuples_pass, remove_default_pass, partition_pass,
InferType()});
Review comment:
Is this order correct? Why should partition_pass come after the previous
two?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]