This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 59b5f6d9ae [Unity][BYOC] Make `MergeCompositeFunctions` append codegen name at the end of function name (#15534)
59b5f6d9ae is described below

commit 59b5f6d9ae0d805ec3afac7aae49645def51056b
Author: Sunghyun Park <[email protected]>
AuthorDate: Sat Aug 12 16:26:46 2023 -0700

    [Unity][BYOC] Make `MergeCompositeFunctions` append codegen name at the end of function name (#15534)
    
    * append codegen name at the end of the funcname
    
    * lint
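    
    The renaming mirrors `cur_var->name_hint + "_" + codegen_name.value()`
    in merge_composite_functions.cc. A minimal sketch of the scheme (the
    helper below is hypothetical, for illustration only, not part of the
    patch):
    
        def append_codegen_name(func_name: str, codegen_name: str) -> str:
            # e.g. "fused_relax_nn_gelu1" + "compiler_B"
            #   -> "fused_relax_nn_gelu1_compiler_B"
            return func_name + "_" + codegen_name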
---
 src/relax/transform/merge_composite_functions.cc   |  81 ++++++++++---
 .../relax/test_transform_fuse_ops_by_pattern.py    |   4 +-
 .../test_transform_merge_composite_functions.py    | 134 ++++++++-------------
 3 files changed, 118 insertions(+), 101 deletions(-)

diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 365b34c601..9d9d9aa644 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -319,6 +319,66 @@ class CompositeInliner : public ExprMutator {
   Map<Function, Function> inlined_functions_;
 };
 
+/*!
+ * \brief Wrap each created composite function with another function, whose body consists
+ * only of a call to the composite function, and annotate the outer function with kCodegen
+ * and kGlobalSymbol attributes.
+ */
+class CompositeFunctionAnnotator : public ExprMutator {
+ public:
+  explicit CompositeFunctionAnnotator(IRModule mod, IRModule new_mod)
+      : ExprMutator(new_mod), mod_(new_mod), inliner(mod) {
+    mod_.CopyOnWrite();
+  }
+  using ExprMutator::VisitExpr_;
+
+  IRModule update() {
+    auto gvar = mod_->GetGlobalVar("main");
+    auto func = Downcast<Function>(mod_->Lookup(gvar));
+    builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
+    return builder_->GetContextIRModule();
+  }
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (call->op->IsInstance<GlobalVarNode>()) {
+      GlobalVar cur_var = Downcast<GlobalVar>(call->op);
+      auto func = Downcast<Function>(mod_->Lookup(cur_var));
+      if (auto codegen_name = func->GetAttr<String>(attr::kCodegen)) {
+        GlobalVar new_var;
+        if (var_map_.count(cur_var) > 0) {
+          // If we have visited this var before, reuse the function we
+          // already created instead of making a new one.
+          new_var = var_map_[cur_var];
+        } else {
+          // On the first visit, create the new function with a new name and
+          // remove the old function from the IRModule under construction.
+          auto old_var = builder_->GetContextIRModule()->GetGlobalVar(cur_var->name_hint);
+          builder_->GetContextIRModule()->Remove(old_var);
+
+          // rename the function.
+          String new_func_name = cur_var->name_hint + "_" + codegen_name.value();
+          Function new_func = inliner.Run(Downcast<Function>(func));
+          new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, new_func_name);
+          new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive);
+          // add a function with a new name.
+          new_var = builder_->AddFunction(new_func, new_func_name);
+          var_map_[cur_var] = new_var;
+        }
+        // Call the new var instead of the old one. The arguments stay the
+        // same; only the function being called changes.
+        return Call(new_var, call->args);
+      }
+    }
+    return GetRef<Call>(call);
+  }
+
+ private:
+  IRModule mod_;
+  CompositeInliner inliner;
+  std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual> var_map_;
+};
+
 }  // namespace
 
 IRModule MergeCompositeFunctions(IRModule mod) {
@@ -327,21 +387,8 @@ IRModule MergeCompositeFunctions(IRModule mod) {
   support::Arena arena;
   auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func);
   auto new_mod = MakeGroupedFunctions(mod, group_map);
+  new_mod = CompositeFunctionAnnotator(mod, new_mod).update();
 
-  CompositeInliner inliner(mod);
-  std::vector<std::pair<GlobalVar, BaseFunc>> to_update;
-  for (const auto& [gvar, func] : new_mod->functions) {
-    if (func->GetAttr<String>(attr::kCodegen)) {
-      auto new_func = inliner.Run(Downcast<Function>(func));
-      new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint);
-      new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive);
-      to_update.emplace_back(gvar, new_func);
-    }
-  }
-
-  for (const auto& [gvar, func] : to_update) {
-    new_mod->Update(gvar, Downcast<Function>(func));
-  }
   // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
   return DeadCodeElimination(new_mod, {"main"});
 }
@@ -351,9 +398,9 @@ namespace transform {
 Pass MergeCompositeFunctions() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
      [=](IRModule mod, PassContext pc) { return relax::MergeCompositeFunctions(mod); };
-  return CreateModulePass(/*pass_function=*/pass_func,       //
-                          /*opt_level=*/0,                   //
-                          /*pass_name=*/"FuseOpsByPattern",  //
+  return CreateModulePass(/*pass_function=*/pass_func,              //
+                          /*opt_level=*/0,                          //
+                          /*pass_name=*/"MergeCompositeFunctions",  //
                           /*required=*/{});
 }
 
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index d95ed68c61..bbf2b0eea4 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -597,7 +597,7 @@ def test_compare_with_merge_composite_path():
     @I.ir_module
     class Expected2:
         @R.function
-        def fused_relax_multiply1(
+        def fused_relax_multiply1_cutlass(
             x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tensor((10, 10), dtype="float32"):
             R.func_attr({"Codegen": "cutlass"})
@@ -624,7 +624,7 @@ def test_compare_with_merge_composite_path():
         ) -> R.Tensor((10, 10), dtype="float32"):
             cls = Expected2
             with R.dataflow():
-                gv: R.Tensor((10, 10), dtype="float32") = cls.fused_relax_multiply1(x, y)
+                gv: R.Tensor((10, 10), dtype="float32") = cls.fused_relax_multiply1_cutlass(x, y)
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py
index d56e1db564..6a36314a74 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -86,24 +86,19 @@ class Conv2dReLUx2_merged:
         with R.dataflow():
             gv: R.Tensor(
                 (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl(
                 data, weight1, weight2
             )
             R.output(gv)
         return gv
 
     @R.function
-    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
         weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-        R.func_attr(
-            {
-                "Codegen": "dnnl",
-                "global_symbol": 
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1",
-            }
-        )
+        R.func_attr({"Codegen": "dnnl"})
 
         @R.function
         def lv(
@@ -211,17 +206,12 @@ class Diamond:
 @tvm.script.ir_module
 class Diamond_merged:
     @R.function
-    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(
+    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
         data: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
         # function attr dict
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": 
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
 
         # block 0
         @R.function
@@ -303,7 +293,9 @@ class Diamond_merged:
         with R.dataflow():
             gv5: R.Tensor(
                 (1, 64, 54, 54), dtype="float32"
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2)
+            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
+                data2, weight2
+            )
             R.output(gv5)
         return gv5
 
@@ -385,27 +377,26 @@ class Diamond_cyclic_dep_merged:
             lv4: R.Tuple(
                 R.Tensor((1, 64, 54, 54), dtype="float32"),
                 R.Tensor((1, 64, 54, 54), dtype="float32"),
-            ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data2, weight2)
+            ) = cls.fused_relax_nn_conv2d_relax_nn_relu_compiler_A(data2, weight2)
             lv12: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[0]
             lv22: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[1]
-            lv31: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_gelu1(lv12)
-            gv5: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_add1(lv22, lv31)
+            lv31: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_gelu1_compiler_B(
+                lv12
+            )
+            gv5: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_add1_compiler_A(
+                lv22, lv31
+            )
             R.output(gv5)
         return gv5
 
     @R.function
-    def fused_relax_nn_conv2d_relax_nn_relu(
+    def fused_relax_nn_conv2d_relax_nn_relu_compiler_A(
         data: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
     ) -> R.Tuple(
        R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")
     ):
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
 
         @R.function
         def lv(
@@ -439,10 +430,10 @@ class Diamond_cyclic_dep_merged:
         return (gv, gv11)
 
     @R.function
-    def fused_relax_nn_gelu1(
+    def fused_relax_nn_gelu1_compiler_B(
         lv2: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-        R.func_attr({"Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"})
+        R.func_attr({"Codegen": "compiler_B"})
 
         @R.function
         def lv21(
@@ -458,11 +449,11 @@ class Diamond_cyclic_dep_merged:
         return gv3
 
     @R.function
-    def fused_relax_add1(
+    def fused_relax_add1_compiler_A(
         lv32: R.Tensor((1, 64, 54, 54), dtype="float32"),
         lv41: R.Tensor((1, 64, 54, 54), dtype="float32"),
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-        R.func_attr({"Codegen": "compiler_A", "global_symbol": 
"fused_relax_add1"})
+        R.func_attr({"Codegen": "compiler_A"})
 
         @R.function
         def lv33(
@@ -529,16 +520,11 @@ class MultipleProducers:
 @tvm.script.ir_module
 class MultipleProducers_merged:
     @R.function
-    def fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(
+    def fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
        x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
         # function attr dict
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": 
"fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
 
         # block 0
         @R.function
@@ -590,7 +576,7 @@ class MultipleProducers_merged:
         with R.dataflow():
             gv4: R.Tensor(
                 (10,), dtype="float32"
-            ) = cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(
+            ) = cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
                 x12, x22
             )
             R.output(gv4)
@@ -647,18 +633,20 @@ class MultipleProducersCyclic_merged:
    def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
         cls = MultipleProducersCyclic_merged
         with R.dataflow():
-            lv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu1(x1)
+            lv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu1_compiler_A(x1)
             lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv)
-            gv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu_relax_add(lv2, lv)
+            gv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu_relax_add_compiler_A(
+                lv2, lv
+            )
             R.output(gv)
         return gv
 
     @R.function
-    def fused_relax_nn_relu1(
+    def fused_relax_nn_relu1_compiler_A(
         x11: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
         # function attr dict
-        R.func_attr({"Codegen": "compiler_A", "global_symbol": 
"fused_relax_nn_relu1"})
+        R.func_attr({"Codegen": "compiler_A"})
 
         # block 0
         @R.function
@@ -675,16 +663,11 @@ class MultipleProducersCyclic_merged:
         return gv1
 
     @R.function
-    def fused_relax_nn_gelu_relax_add(
+    def fused_relax_nn_gelu_relax_add_compiler_A(
        lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
         # function attr dict
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": "fused_relax_nn_gelu_relax_add",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
         # block 0
 
         @R.function
@@ -770,17 +753,12 @@ class MergeCompilerRegionsExample:
 @tvm.script.ir_module
 class MergeCompilerRegionsExampleRef:
     @R.function
-    def fused_relax_add_relax_add_relax_nn_relu(
+    def fused_relax_add_relax_add_relax_nn_relu_compiler_A(
         x1: R.Tensor((10,), dtype="float32"),
         x2: R.Tensor((10,), dtype="float32"),
         lv: R.Tensor((10,), dtype="float32"),
    ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32")):
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": "fused_relax_add_relax_add_relax_nn_relu",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
 
         @R.function
         def lv1(
@@ -807,15 +785,10 @@ class MergeCompilerRegionsExampleRef:
         return (gv1, gv11)
 
     @R.function
-    def fused_relax_add_relax_nn_relu(
+    def fused_relax_add_relax_nn_relu_compiler_A(
        lv12: R.Tensor((10,), dtype="float32"), lv3: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
-        R.func_attr(
-            {
-                "Codegen": "compiler_A",
-                "global_symbol": "fused_relax_add_relax_nn_relu",
-            }
-        )
+        R.func_attr({"Codegen": "compiler_A"})
 
         @R.function
         def lv21(
@@ -842,10 +815,10 @@ class MergeCompilerRegionsExampleRef:
         return gv3
 
     @R.function
-    def fused_relax_nn_gelu1(
+    def fused_relax_nn_gelu1_compiler_B(
         x3: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
-        R.func_attr({"Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"})
+        R.func_attr({"Codegen": "compiler_B"})
 
         @R.function
        def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
@@ -866,14 +839,16 @@ class MergeCompilerRegionsExampleRef:
     ) -> R.Tensor((10,), dtype="float32"):
         cls = MergeCompilerRegionsExampleRef
         with R.dataflow():
-            lv5: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1(x32)
+            lv5: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1_compiler_B(x32)
             lv13: R.Tuple(
                R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32")
-            ) = cls.fused_relax_add_relax_add_relax_nn_relu(x12, x22, lv5)
+            ) = cls.fused_relax_add_relax_add_relax_nn_relu_compiler_A(x12, x22, lv5)
             lv23: R.Tensor((10,), dtype="float32") = lv13[0]
             lv32: R.Tensor((10,), dtype="float32") = lv13[1]
-            lv41: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1(lv23)
-            gv6: R.Tensor((10,), dtype="float32") = cls.fused_relax_add_relax_nn_relu(lv41, lv32)
+            lv41: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1_compiler_B(lv23)
+            gv6: R.Tensor((10,), dtype="float32") = cls.fused_relax_add_relax_nn_relu_compiler_A(
+                lv41, lv32
+            )
             R.output(gv6)
         return gv6
 
@@ -917,7 +892,7 @@ class ModuleWithNonComposite_ref:
     ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
         cls = ModuleWithNonComposite_ref
         with R.dataflow():
-            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d1(
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d1_tensorrt(
                 data, weight
             )
             conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
@@ -925,11 +900,11 @@ class ModuleWithNonComposite_ref:
         return conv
 
     @R.function
-    def fused_relax_nn_conv2d1(
+    def fused_relax_nn_conv2d1_tensorrt(
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
     ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
-        R.func_attr({"Codegen": "tensorrt", "global_symbol": 
"fused_relax_nn_conv2d1"})
+        R.func_attr({"Codegen": "tensorrt"})
 
         @R.function
         def lv1(
@@ -1076,17 +1051,12 @@ def test_reshape():
     @I.ir_module
     class Expected:
         @R.function
-        def fused_relax_reshape_relax_matmul(
+        def fused_relax_reshape_relax_matmul_tensorrt(
             inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
             param_0: R.Shape([1, 784]),
             lv1: R.Tensor((784, 512), dtype="float32"),
         ) -> R.Tensor((1, 512), dtype="float32"):
-            R.func_attr(
-                {
-                    "Codegen": "tensorrt",
-                    "global_symbol": "fused_relax_reshape_relax_matmul",
-                }
-            )
+            R.func_attr({"Codegen": "tensorrt"})
             # from tvm.script import relax as R
 
             @R.function
@@ -1128,9 +1098,9 @@ def test_reshape():
                 lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
                     linear_relu_stack_0_weight, axes=None
                 )
-                gv: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_reshape_relax_matmul(
-                    inp_0, R.shape([1, 784]), lv1
-                )
+                gv: R.Tensor(
+                    (1, 512), dtype="float32"
+                ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, R.shape([1, 784]), lv1)
                 R.output(gv)
             return gv
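
To see the renaming end to end, here is a minimal sketch of driving the pass from Python, assuming the standard unity-branch APIs (FuseOpsByPattern, MergeCompositeFunctions). The pattern name "cutlass.multiply" and the module are illustrative, adapted from the updated test, not part of the patch:

import tvm
from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((10, 10), "float32"), y: R.Tensor((10, 10), "float32")
    ) -> R.Tensor((10, 10), "float32"):
        with R.dataflow():
            gv = R.multiply(x, y)
            R.output(gv)
        return gv


# "cutlass.multiply" tags the pattern with the "cutlass" codegen.
# FuseOpsByPattern creates the composite function; MergeCompositeFunctions
# wraps it in an outer function that, with this patch, carries the codegen
# name as a suffix.
patterns = [("cutlass.multiply", is_op("relax.multiply")(wildcard(), wildcard()))]
seq = tvm.transform.Sequential(
    [
        relax.transform.FuseOpsByPattern(patterns),
        relax.transform.MergeCompositeFunctions(),
    ]
)
mod = seq(Module)

# Expect a name like "fused_relax_multiply1_cutlass" alongside "main";
# before this patch the "_cutlass" suffix was absent.
print([gv.name_hint for gv in mod.get_global_vars()])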
 
