This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 15ba19fa78 [Unity][BYOC] Assign group to unused bindings and ignore PrimFunc (#14139)
15ba19fa78 is described below

commit 15ba19fa78148e6d9146fdba4539a0d9ba1dbf47
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Feb 27 12:20:39 2023 -0800

    [Unity][BYOC] Assign group to unused bindings and ignore PrimFunc (#14139)
    
    * [Unity][BYOC] Assign group to unused bindings and ignore PrimFunc
    
    * Update fuse_ops.cc
---
 src/relax/transform/fuse_ops.cc                    |  46 ++++----
 src/relax/transform/run_codegen.cc                 |   3 +
 .../relax/test_transform_fuse_ops_by_pattern.py    | 121 ++++++++++++++++++++-
 3 files changed, 144 insertions(+), 26 deletions(-)
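
For reference, a minimal usage sketch (not part of the commit): it shows how the pass touched here is typically driven from Python, mirroring the new tests in the diff below. The Conv2dReLUCallTIR name refers to the test module defined in test_ignore_call_tir further down, and the FuseOpsByPattern call shape is assumed from the test helper in that file.

    # Hypothetical usage sketch. After this change, tir.PrimFunc entries in the
    # module are skipped by the pattern-based partitioner, and bindings that match
    # no pattern (including unused ones) are still assigned their own group.
    from tvm import relax
    from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern

    # Match a bare conv2d (no bias, no activation), as in the tests below.
    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)

    # Conv2dReLUCallTIR is the IRModule defined in test_ignore_call_tir below.
    partitioned = relax.transform.FuseOpsByPattern([("cutlass.conv2d", pat)])(Conv2dReLUCallTIR)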

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 813c0c8f03..c5042d0191 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -890,14 +890,6 @@ IRModule MakeGroupedFunctions(
   return OperatorFusor(mod, partition, lift_constants).Transform();
 }
 
-static Map<Expr, Var> GetBindingInverse(const Map<Var, Expr>& binding) {
-  Map<Expr, Var> value_to_bound_var;
-  for (const auto& [var, val] : binding) {
-    value_to_bound_var.Set(val, var);
-  }
-  return value_to_bound_var;
-}
-
 /*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group,
  * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations
  * in a group and create a grouped function.
@@ -909,21 +901,26 @@ class PatternBasedPartitioner : ExprVisitor {
   using ExprVisitor::VisitExpr_;
 
   static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) {
-    PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr));
-    // Initialize each expr to have its own group
-    PostOrderVisit(
-        expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make<Group>(); });
+    PatternBasedPartitioner part(pattern_name, pattern, arena);
     part.VisitExpr(expr);
     return part.group_map_;
   }
 
-  PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map<Var, Expr>& bindings)
-      : pat_name_(pattern_name),
-        pat_(pattern),
-        bindings_(bindings),
-        value_to_bound_var_(GetBindingInverse(bindings)) {}
+  PatternBasedPartitioner(String pattern_name, DFPattern pattern, support::Arena* arena)
+      : pat_name_(pattern_name), pat_(pattern), arena_(arena) {}
+
+  void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    bindings_.Set(binding->var, binding->value);
+    value_to_bound_var_.Set(binding->value, binding->var);
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExpr_(const ConstantNode* op) final { group_map_[op] = arena_->make<Group>(); }
 
-  void VisitExpr_(const CallNode* call) override {
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
+    VisitVarDef(binding->var);
     if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
       // If a match is found, put all matching expressions into the same group.
      // OperatorFusor also requires that the bound variable be in the same group as the RHS value.
@@ -939,15 +936,12 @@ class PatternBasedPartitioner : ExprVisitor {
       //   conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
 
       // parent_group corresponds to the group of "conv1" above.
-      auto parent_group = GetGroupForBoundVar(GetRef<Call>(call));
+      auto parent_group = GetGroupForBoundVar(binding->var);
       ICHECK(parent_group);
       parent_group->attrs.Set(attr::kComposite, pat_name_);
-
       for (const auto& [pat, match] : matches_opt.value()) {
-        ICHECK(group_map_.count(match.get()));
         // Put all matching call nodes into the parent group.
         if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) {
-          AddToGroup(match, parent_group);
           // Put the bound variable on the LHS into the same parent group.
           AddToGroup(value_to_bound_var_[match], parent_group);
         }
@@ -964,15 +958,14 @@ class PatternBasedPartitioner : ExprVisitor {
     }
   }
 
-  Group* GetGroupForBoundVar(Expr e) {
-    ICHECK(value_to_bound_var_.count(e));
-    auto bound_var = value_to_bound_var_[e];
+  Group* GetGroupForBoundVar(const Var& bound_var) {
     ICHECK(group_map_.count(bound_var.get()));
     return group_map_[bound_var.get()]->FindRoot();
   }
 
   String pat_name_;
   DFPattern pat_;
+  support::Arena* arena_;
   Map<Var, Expr> bindings_;
   Map<Expr, Var> value_to_bound_var_;
   GroupMap group_map_;
@@ -1055,6 +1048,9 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
   for (size_t i = 0; i < pattern_names.size(); ++i) {
     OperatorFusor::GroupMap group_map;
     for (const auto& entry : mod->functions) {
+      if (entry.second->IsInstance<tir::PrimFuncNode>()) {
+        continue;
+      }
       auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena);
       group_map.insert(map.begin(), map.end());
     }
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index 114b7d2a34..7deeb139d1 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -138,6 +138,9 @@ class CodeGenRunner : ExprMutator {
     std::unordered_map<std::string, Array<Function>> target_functions;
 
     for (const auto& entry : mod->functions) {
+      if (entry.second->IsInstance<tir::PrimFuncNode>()) {
+        continue;
+      }
       PostOrderVisit(entry.second, [&target_functions](Expr e) {
         if (e->IsInstance<FunctionNode>()) {
           auto f = Downcast<Function>(e);
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index da5b92fb64..21f952096b 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -20,7 +20,7 @@ import numpy as np
 import tvm
 
 from tvm import relax
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T, ir as I
 from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard
 
 
@@ -460,5 +460,124 @@ def test_multiple_calls_same_extern():
     check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annoatate_codegen=True)
 
 
+def test_ignore_call_tir():
+    @I.ir_module
+    class Conv2dReLUCallTIR:
+        @T.prim_func
+        def relu(
+            data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")
+        ):
+            for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+                with T.block("root"):
+                    i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    out[i, j, k, l] = T.max(data[i, j, k, l], 0.0)
+
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), "float32"),
+            weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        ):
+            with R.dataflow():
+                conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+                relu1 = R.call_tir(relu, (conv1,), R.Tensor((64, 64, 56, 56), "float32"))
+                R.output(relu1)
+
+            return relu1
+
+    @I.ir_module
+    class Conv2dReLUCallTIR_partitioned:
+        @T.prim_func
+        def relu(
+            data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")
+        ):
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+                with T.block("root"):
+                    i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(data[i, j, k, l])
+                    T.writes(out[i, j, k, l])
+                    out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))
+
+        @R.function
+        def fused_relax_nn_conv2d(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                    data,
+                    weight1,
+                    padding=(1, 1),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(
+                    data, weight1
+                )
+                relu1 = R.call_tir(
+                    relu, (lv,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32")
+                )
+                R.output(relu1)
+            return relu1
+
+    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
+    check(Conv2dReLUCallTIR, [("cutlass.conv2d", pat)], Conv2dReLUCallTIR_partitioned)
+
+
+def test_unused():
+    @I.ir_module
+    class Conv2dReLU:
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), "float32"),
+            weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        ):
+            with R.dataflow():
+                conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+                relu = R.nn.relu(data)
+                R.output(conv1)
+
+            return conv1
+
+    @I.ir_module
+    class Conv2dReLU_partitioned:
+        @R.function
+        def fused_relax_nn_conv2d(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                    data, weight1, padding=(1, 1)
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(
+                    data, weight1
+                )
+                relu: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(data)
+                R.output(gv)
+            return gv
+
+    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
+    check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
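
For reference, a sketch (assumed invocation, not part of the commit) for running just the two tests added in this diff:

    # Select only the tests added by this commit via pytest's -k expression.
    import pytest
    pytest.main(["tests/python/relax/test_transform_fuse_ops_by_pattern.py",
                 "-k", "ignore_call_tir or unused"])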
