This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f62445cdd9 [Relax] Disable fusion for fetching from the packed params in FuseOps (#17198)
f62445cdd9 is described below

commit f62445cdd96a415d332585aa9702eaf1df3cf972
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun Jul 28 13:57:09 2024 -0700

    [Relax] Disable fusion for fetching from the packed params in FuseOps (#17198)
    
    * [Relax] Disable fusion for fetching from the packed params in FuseOps
    
    The order of bindings in the fusion result is determined by the first
    binding in each partition group. When the packed param tuple is used,
    the function usually begins with a number of `TupleGetItem` bindings
    to unpack the param tuple. Previously, `TupleGetItem` was treated as
    `kInjective`, which caused any operation that relies purely on these
    params to be moved to the beginning of the function, increasing the
    memory usage of the intermediate results.
    
    * lint
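
For context, a minimal sketch of how the pass is applied in this scenario
(illustrative only; the module name `Before` mirrors the test added below
and is not part of the commit):

    import tvm
    from tvm import relax

    # `Before.main` takes its weights as a packed tuple parameter and
    # carries R.func_attr({"num_input": 1}); FuseOps uses that attribute
    # to tell the real inputs apart from the packed params.
    mod = relax.transform.FuseOps()(Before)

    # With this change, the TupleGetItem bindings that unpack the param
    # tuple are marked opaque, so computation depending only on the
    # params is no longer hoisted to the start of the function.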
---
 src/relax/transform/fuse_ops.cc               | 19 +++++++++--
 tests/python/relax/test_transform_fuse_ops.py | 48 +++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 6030a28d93..e791aeab06 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -147,6 +147,12 @@ class GraphCreator : public ExprVisitor {
       SetNodePattern(param_node, OpPatternKind::kOpaque);
       AddToPostDFSOrder(param_node, param.get());
     }
+    if (auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput)) {
+      for (int i = static_cast<int>(opt_num_input.value()->value);
+           i < static_cast<int>(func->params.size()); ++i) {
+        input_params_.insert(func->params[i].get());
+      }
+    }
     ExprVisitor::VisitExpr_(func);
   }
 
@@ -224,8 +230,15 @@ class GraphCreator : public ExprVisitor {
                          IndexedForwardGraph::Node* binding_var_node) {
     ICHECK_NOTNULL(binding_var_node);
 
-    SetNodePattern(binding_var_node, OpPatternKind::kInjective);
-    VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+    auto pattern = OpPatternKind::kInjective;
+    if (input_params_.count(tuple_item->tuple.as<VarNode>())) {
+      // TupleGetItem for fetching a parameter from the packed param tuple is treated as opaque
+      // and won't be fused. This prevents the usage of the packed param tuple from changing the
+      // order of the fusion result, as the function usually begins with fetching the parameters.
+      pattern = OpPatternKind::kOpaque;
+    }
+    SetNodePattern(binding_var_node, pattern);
+    VisitLeaf(tuple_item->tuple, binding_var_node, pattern);
   }
 
   void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) {
@@ -354,6 +367,8 @@ class GraphCreator : public ExprVisitor {
   IndexedForwardGraph graph_;
   /*! \brief The graph nodes whose patterns are set */
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+  /*! \brief The model params in the function input */
+  std::unordered_set<const VarNode*> input_params_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 3cd608d8ee..17bf586132 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1642,5 +1642,53 @@ def test_call_tir_inplace():
     _check(Module, Expected)
 
 
+def test_packed_params():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(16), T.int64(16)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(lv[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Cast("float32", lv[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)):
+                with T.block("T_matmul"):
+                    v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+                    T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1])
+                    T.writes(T_matmul[v_ax0, v_ax1])
+                    with T.init():
+                        T_matmul[v_ax0, v_ax1] = T.float32(0)
+                    T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + x[v_ax0, v_k] * lv2[v_k, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor((16, 16), dtype="float16"), R.Tensor((16, 16), dtype="float16"))) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((16, 16), dtype="float16") = packed_params[0]
+                lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1]
+                lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv5
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    Expected = Before
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
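
The new test expects FuseOps to leave this module unchanged (Expected = Before).
Assuming `_check` in this test file applies the pass and compares the result
against the expectation structurally, an equivalent manual check would be
roughly:

    import tvm
    from tvm import relax

    # The packed-param TupleGetItem bindings are opaque, so no fusion groups
    # form and the module should come out structurally identical.
    after = relax.transform.FuseOps()(Before)
    tvm.ir.assert_structural_equal(after, Before)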
