This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5bfca2e7a2 [Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
5bfca2e7a2 is described below

commit 5bfca2e7a25a357e5b3399ade98461a2678e8fc5
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Mon Jun 17 22:01:54 2024 +0530

    [Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
    
    Arguments of a fused TIR PrimFunc generated from a fused relax function
    do not retain all the buffer attributes of their original PrimFuncs,
    because the buffers are created from the StructInfo of the relax vars.
    This patch collects a mapping from relax vars to their corresponding TIR
    buffers in a fused relax function and uses that information to propagate
    buffer attributes such as `axis_separators` and `storage_scope`.
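    
    As a minimal sketch (assuming an input module `Before` laid out like the
    new `test_fuse_with_axis_separators` case in the diff below), the
    propagated attributes can be inspected on the fused PrimFunc after
    running the pass:
    
        import tvm
        from tvm import relax
    
        # `Before` is a hypothetical IRModule shaped like the new test case:
        # a private PrimFunc whose buffers declare axis_separators, wrapped by
        # a "Primitive" relax function and a public main.
        after = relax.transform.FuseTIR()(Before)
    
        # The fused PrimFunc is expected to keep the attributes of the original
        # buffers instead of defaults derived from the relax StructInfo.
        fused = after["fused_function"]
        for param in fused.params:
            buf = fused.buffer_map[param]
            print(buf.name, buf.axis_separators, buf.scope())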
---
 src/relax/transform/fuse_tir.cc               | 140 ++++++++++++++++++++++----
 tests/python/relax/test_transform_fuse_tir.py | 128 +++++++++++++++++++++++
 2 files changed, 248 insertions(+), 20 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index e712b5022a..b203b322ab 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding) final {
+    current_var_ = binding->var;
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+    ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+        << "Only call_tir and call_tir_inplace are supported in primitive 
function, but got: "
+        << GetRef<Expr>(call);
+    CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
+  }
+
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+
+    // If the `expr` is already seen (present in the map), validate whether the mapped buffer is
+    // structurally equal to the `new_buf` passed
+    auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
+      if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
+        ICHECK(StructuralEqual()((*it).second, new_buf))
+            << "Inconsistent buffers " << (*it).second << " and " << new_buf
+            << " mapped to the same relax var: " << expr;
+      }
+    };
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = tir_args[i];
+      if (auto tir_buffer = buffer_map.Get(tir_var)) {
+        if (i < num_inputs) {
+          const auto& relax_var = relax_args[i];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+        if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
+            it != output_idxs.end()) {
+          int result_idx = it - output_idxs.begin();
+          const auto& relax_var = relax_results[result_idx];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  Map<Expr, tir::Buffer> relax_to_tir_var_map_;
+  Var current_var_;
+};
+
 class FusedTIRConstructor : public ExprVisitor {
  public:
   /*!
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
     std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
       size_t size_before = prim_func_params.size();
-      CollectPrimFuncParams(relax_param, &prim_func_params);
+      CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));
 
       auto param_buffers = [&]() -> Array<tir::Buffer> {
         Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
     MapArgsToBuffer(arg_list, buffer_list);
   }
 
-  static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
-                                                int num_inputs) {
-    Array<Integer> ret;
-    int last_idx = num_inputs;
-    for (auto idx : inplace_indices) {
-      int i = idx.IntValue();
-      if (i >= 0) {
-        ret.push_back(Integer(i));
-      } else {
-        ret.push_back(Integer(last_idx));
-        last_idx++;
-      }
-    }
-
-    return ret;
-  }
-
   static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
                                                 const Array<Integer>& output_indices) {
     size_t n = func->params.size();
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
    * \param out The vector into which to collect the params/buffers
    */
   static void CollectPrimFuncParams(const Var& relax_param,
-                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out,
+                                    const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
     auto struct_info = GetStructInfo(relax_param);
 
     CHECK(!struct_info.as<TupleStructInfoNode>())
@@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
      ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      tir::Buffer buffer;
+      if (tir_buffer_param.defined()) {
+        buffer =
+            tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
+                             tir_buffer_param.value()->axis_separators);
+      } else {
+        buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      }
       out->push_back(std::move(buffer));
 
     } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 90baeaad04..99e7a5d2b7 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax, topi
@@ -2314,5 +2316,131 @@ def test_private_nonprimitive_func():
     _check(Before, Before)
 
 
+def test_fuse_with_axis_separators():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                w = R.call_tir(
+                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                out = R.call_tir(
+                    cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x, y, z)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
+            T.func_attr({"tir.noalias": True})
+            X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Y"):
+                    i, j = T.axis.remap("SS", iters)
+                    Temp[i, j] = X[i, j] + Y[i, j]
+
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Z"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = Temp[i, j] + Z[i, j]
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.fused_function,
+                    [x, y, z],
+                    out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def mul(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] * B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                out = R.call_tir(
+                    cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(
+        tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
+    ):
+        relax.transform.FuseTIR()(Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
