This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2d7c673fcc [Unity] Add bind_constants option to FuseOpsByPattern 
(#14151)
2d7c673fcc is described below

commit 2d7c673fccb89ed11355726fd06a29c69474b4fd
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 1 12:55:00 2023 -0800

    [Unity] Add bind_constants option to FuseOpsByPattern (#14151)
    
    * [Unity] Add lift_constants option to FuseOpsByPattern
    
    * lift_constants -> bind_constants
---
 include/tvm/relax/transform.h                      |  4 +-
 python/tvm/relax/backend/contrib/cutlass.py        |  4 +-
 python/tvm/relax/transform/transform.py            |  7 ++-
 src/relax/transform/fuse_ops.cc                    | 12 +++--
 tests/python/relax/test_codegen_cutlass.py         |  2 +-
 .../relax/test_transform_fuse_ops_by_pattern.py    | 61 ++++++++++++++++++++--
 6 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 907f0bf8cd..715c8e56ff 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -225,6 +225,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * of priority in which they are matched. Higher-priority patterns should come 
earlier in the list.
  * \param checks The callback functions with type (Map<DFPattern, Expr>, Expr) 
-> bool. It takes a
  * match result and returns a boolean value to indicate whether the match 
result is accepted.
+ * \param bind_constants Whether or not to keep bound constants in the grouped 
function.
  * \param annotate_codegen If true, wrap each created composite function with 
another function,
  * whose body consists only of a call to the composite function, and annotate 
the outer function
  * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set 
as the prefix of the
@@ -235,7 +236,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  */
 TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
                               const tvm::Array<DFPattern>& patterns,
-                              const tvm::Array<PackedFunc>& checks, bool 
annotate_codegen = false);
+                              const tvm::Array<PackedFunc>& checks, bool 
bind_constants = true,
+                              bool annotate_codegen = false);
 
 /*!
  * \brief Group one or multiple composite functions created by 
FuseOpsByPattern into a new
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 51684abb06..e98194ca21 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -117,4 +117,6 @@ def partition_for_cutlass(mod):
     """
 
     cutlass_patterns = get_patterns_with_prefix("cutlass")
-    return transform.FuseOpsByPattern(cutlass_patterns, 
annotate_codegen=True)(mod)
+    return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True, 
annotate_codegen=True)(
+        mod
+    )
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 97daae4941..a33ad63093 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -283,7 +283,7 @@ def FuseTIR() -> tvm.ir.transform.Pass:
 
 
 def FuseOpsByPattern(
-    patterns: List[Tuple], annotate_codegen: bool = False
+    patterns: List[Tuple], bind_constants: bool = True, annotate_codegen: bool 
= False
 ) -> tvm.ir.transform.Pass:
     """Apply pattern matching to each function in the given module, and group 
matched expressions
     into a new function.
@@ -302,6 +302,9 @@ def FuseOpsByPattern(
         The string is the name of the corresponding pattern. It becomes the 
value of the kComposite
         attribute of a fused function after a successful matching.
 
+    bind_constants : bool
+        Whether or not to keep bound constants in the grouped function.
+
     annotate_codegen : bool
         If True, wrap each created composite function with another function, 
whose body consists
         only of a call to the composite function, and annotate the outer 
function with "Codegen"
@@ -332,7 +335,7 @@ def FuseOpsByPattern(
         else:
             raise ValueError("Invalid pattern: {}".format(tup))
     return _ffi_api.FuseOpsByPattern(
-        pattern_names, df_patterns, checks, annotate_codegen
+        pattern_names, df_patterns, checks, bind_constants, annotate_codegen
     )  # type: ignore
 
 
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 60b2c77e49..72427a8a0e 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -574,6 +574,8 @@ class OperatorFusor : public ExprMutator {
    * \param mod The IRModule to be transformed
    * \param graph The indexed-forward graph of the input IRModule
    * \param groups The grouped result of the group partition on the input 
indexed-forward graph.
+   * \param lift_constant Whether or not to lift bound constants to parameters 
of the grouped
+   * function.
    */
   OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const 
std::vector<Group*>& groups,
                 bool lift_constant = true)
@@ -1052,7 +1054,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
 IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
                           const tvm::Array<DFPattern>& patterns,
                           const tvm::Array<runtime::PackedFunc>& checks, 
IRModule mod,
-                          bool annotate_codegen) {
+                          bool bind_constants, bool annotate_codegen) {
   support::Arena arena;
   for (size_t i = 0; i < pattern_names.size(); ++i) {
     OperatorFusor::GroupMap group_map;
@@ -1064,7 +1066,7 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& 
pattern_names,
                                               entry.second, &arena);
       group_map.insert(map.begin(), map.end());
     }
-    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
+    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ 
!bind_constants);
   }
   if (annotate_codegen) {
     return CompositeFunctionAnnotator(mod).Run();
@@ -1091,10 +1093,12 @@ 
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
 
 Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
                       const tvm::Array<DFPattern>& patterns,
-                      const tvm::Array<runtime::PackedFunc>& checks, bool 
annotate_codegen) {
+                      const tvm::Array<runtime::PackedFunc>& checks, bool 
bind_constants,
+                      bool annotate_codegen) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
-        return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, 
annotate_codegen);
+        return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, 
bind_constants,
+                                       annotate_codegen);
       };
   return CreateModulePass(/*pass_function=*/pass_func,       //
                           /*opt_level=*/0,                   //
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index bce8c5a84f..6eb476496c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -100,7 +100,7 @@ def get_result_with_relax_cutlass_offload(mod, *args):
 
     seq = tvm.transform.Sequential(
         [
-            relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
+            relax.transform.FuseOpsByPattern(patterns, bind_constants=False, 
annotate_codegen=True),
             relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}}),
         ]
     )
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py 
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index a9e76feb6c..b4b5591c01 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -389,8 +389,8 @@ conv2d_pat = 
make_fused_bias_activation_pattern("relax.nn.conv2d", activation=No
 conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", 
activation="relax.nn.relu")
 
 
-def check(mod, patterns, expected, annoatate_codegen=False):
-    partitioned = relax.transform.FuseOpsByPattern(patterns, 
annoatate_codegen)(mod)
+def check(mod, patterns, expected, bind_constants=True, 
annoatate_codegen=False):
+    partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, 
annoatate_codegen)(mod)
     tvm.ir.assert_structural_equal(partitioned, expected)
 
 
@@ -424,7 +424,9 @@ def test_cyclic_dependency():
     add_pat = is_op("relax.add")(relu_pat, wildcard())
 
     with pytest.raises(tvm.error.TVMError) as err:
-        relax.transform.FuseOpsByPattern([("compiler_A.conv2d_relu_add", 
add_pat)])(Branch)
+        relax.transform.FuseOpsByPattern(
+            [("compiler_A.conv2d_relu_add", add_pat)], bind_constants=True
+        )(Branch)
 
     assert "A cyclic dependency detected" in str(err.value)
 
@@ -434,7 +436,9 @@ def test_bind_params():
     mod = tvm.transform.Sequential(
         [
             relax.transform.BindParams("main", {"weight1": weight_np}),
-            relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", 
conv2d_relu_pat)]),
+            relax.transform.FuseOpsByPattern(
+                [("dnnl.conv2d_relu", conv2d_relu_pat)], bind_constants=True
+            ),
         ]
     )(Conv2dReLU)
 
@@ -589,5 +593,54 @@ def test_check_pattern():
     check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2)  # expect no 
partitioning
 
 
+def test_bind_constants():
+    weight = np.random.randn(64, 64, 3, 3).astype("float32")
+
+    @I.ir_module
+    class Conv2dWithConstantWeight:
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), "float32"),
+            weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        ):
+            with R.dataflow():
+                conv1 = R.nn.conv2d(data, R.const(weight), padding=(1, 1))
+                R.output(conv1)
+            return conv1
+
+    @I.ir_module
+    class Conv2dWithConstantWeight_partitioned:
+        @R.function
+        def fused_relax_nn_conv2d(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            param_0: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+            with R.dataflow():
+                gv = R.nn.conv2d(data, param_0, padding=(1, 1))
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = 
fused_relax_nn_conv2d(
+                    data, R.const(weight)
+                )
+                R.output(gv)
+            return gv
+
+    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", 
with_bias=False, activation=None)
+    check(
+        Conv2dWithConstantWeight,
+        [("cutlass.conv2d", pat)],
+        Conv2dWithConstantWeight_partitioned,
+        bind_constants=False,
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to