This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a3fadc065 [Unity][Transform] Handle `call_tir_inplace` in `FuseTIR` 
and `FuseOps` (#16487)
6a3fadc065 is described below

commit 6a3fadc0654ecf9557ffe08d24677684c96e80b0
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 6 15:38:29 2024 -0500

    [Unity][Transform] Handle `call_tir_inplace` in `FuseTIR` and `FuseOps` 
(#16487)
    
    * WIP initial commit
    
    * Handle in-place calls in FuseTIR
    
    * Formatting
    
    * Add test case for FuseOps
    
    * Address review comments related to clarity
    
    * Use a set to ensure in-place indices will be unique
    
    * Add test case where PrimFunc is used both in-place and DPS
    
    * Explicitly check for duplicate index
---
 src/relax/transform/fuse_ops.cc               |  10 +-
 src/relax/transform/fuse_tir.cc               | 158 ++++++++++---
 tests/python/relax/test_transform_fuse_ops.py | 141 +++++++++++
 tests/python/relax/test_transform_fuse_tir.py | 324 ++++++++++++++++++++++++++
 4 files changed, 600 insertions(+), 33 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index b0eeba399e..32780f6dd2 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor {
     ICHECK_NOTNULL(binding_var_node);
 
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
     OpPatternKind pattern = OpPatternKind::kOpaque;
     Array<Expr> args = call->args;
 
@@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor {
     // - Otherwise, the pattern of the current binding variable node is set to 
`kOpaque`, and we
     // recurse into the call expression.
     const auto* op = call->op.as<OpNode>();
-    if (op == call_tir_op_.get()) {
+    if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) {
       const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
       tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
 
@@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator {
    * function accordingly
    * \param binding The binding to be appended
    * \note Allowed bindings are:
-   *  - VarBinding with value being a call node calling `relax.call_tir`.
+   *  - VarBinding with value being a call node calling `relax.call_tir` or
+   *    `relax.call_tir_inplace`.
    *  - VarBinding with value being a tuple-get-item node.
    * // TODO(tvm-team): handle match shape
    */
@@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator {
 
     if (const auto* var_binding = binding.as<VarBindingNode>()) {
       if (const auto* call = var_binding->value.as<CallNode>()) {
-        if (call->op == Op::Get("relax.call_tir")) {
+        if (call->op == Op::Get("relax.call_tir") ||
+            call->op == Op::Get("relax.call_tir_inplace")) {
           // Update the name of the function.
           name_hint_ = name_hint_ + "_" + 
Downcast<GlobalVar>(call->args[0])->name_hint;
 
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 1c25229d88..4ad291e91c 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
@@ -367,9 +368,10 @@ class FusedTIRConstructor : public ExprVisitor {
    * \brief Construct a fused TIR PrimFunc from a relax sub-function
    * \param mod The IRModule
    * \param gv The global var of relax subfunction to be fused into one 
PrimFunc
-   * \return The fused TIR PrimFunc
+   * \return The fused TIR PrimFunc and the in-place indices (non-empty for an 
in-place call)
    */
-  static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+  static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR(const IRModule& 
mod,
+                                                              const GlobalVar& 
gv) {
     FusedTIRConstructor visitor(mod, gv->name_hint);
     BaseFunc f = mod->Lookup(gv);
     CHECK(f->IsInstance<relax::FunctionNode>())
@@ -377,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor {
     CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
         << "Expected a function with attr `kPrimitive`";
     visitor(Downcast<relax::Function>(f));
-    return visitor.fused_tir_;
+    Array<Integer> inplace_indices;
+    for (size_t idx : visitor.inplace_indices_) {
+      inplace_indices.push_back(Integer(idx));
+    }
+    return {visitor.fused_tir_, inplace_indices};
   }
 
  private:
@@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor {
     auto it = func_info_.expr2buffers.find(body);
     ICHECK(it != func_info_.expr2buffers.end())
         << "Fail to detect output buffers for function body";
+
     const Array<tir::Buffer>& buffers = (*it).second;
+
+    // map of input buffers to indices (helpful for detecting in-place inputs)
+    std::unordered_map<tir::Buffer, size_t, ObjectPtrHash, ObjectPtrEqual> 
buffer_to_idx;
+    std::unordered_map<tir::Var, size_t, ObjectPtrHash, ObjectPtrEqual> 
input_to_idx;
+    for (size_t i = 0; i < func_info_.params.size(); i++) {
+      input_to_idx[func_info_.params[i]] = i;
+    }
+    for (auto [var, buffer] : func_info_.buffer_map) {
+      if (auto it = input_to_idx.find(var); it != input_to_idx.end()) {
+        buffer_to_idx[buffer] = (*it).second;
+      }
+    }
+
+    // numbered separately because the number of output *vars* might differ 
from the
+    // number of outputs if there are in-place inputs
+    int out_idx = 0;
     for (size_t i = 0; i < buffers.size(); ++i) {
-      tir::Var param = tir::Var("p_output" + std::to_string(i), 
PrimType(DataType::Handle()));
+      // Do not add output vars for in-place inputs
+      // (i.e., those already listed in the buffer map; adding them again
+      // would create duplicates in the buffer map)
+      if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) 
{
+        auto idx = (*it).second;
+        CHECK(!inplace_indices_.count(idx))
+            << "In-place index " << idx << " used twice! This implies an 
argument is aliased.";
+        inplace_indices_.insert(idx);
+        continue;
+      }
+
+      tir::Var param = tir::Var("p_output" + std::to_string(out_idx), 
PrimType(DataType::Handle()));
+      out_idx++;
       func_info_.buffer_map.Set(param, buffers[i]);
       func_info_.params.push_back(param);
       func_info_.output_buffers.insert(buffers[i].get());
@@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor {
   void VisitExpr_(const CallNode* call) final {
     ExprVisitor::VisitExpr_(call);
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
-    ICHECK(call->op == call_tir_op_)
-        << "Only call_tir is supported in primitive function, but got: " << 
GetRef<Expr>(call);
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+    ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+        << "Only call_tir and call_tir_inplace are supported in primitive 
function, but got: "
+        << GetRef<Expr>(call);
 
     // Step 1. Get Global var and PrimFunc
     GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
@@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor {
     MapInputBuffer(prim_func, call->args[1]);
     const Array<Array<PrimExpr>>& output_buffer_shapes = 
GetCallTIROutputShapes(call);
 
-    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, 
output_buffer_shapes);
+    AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes);
 
     // Step 6. Update tir_vars
     if (call->args.size() > 2) {
@@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor {
    */
   static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
-    ICHECK(call->op.same_as(call_tir_op_));
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+    ICHECK(call->op.same_as(call_tir_op_) || 
call->op.same_as(call_tir_inplace_op_));
     ICHECK_EQ(call->sinfo_args.size(), 1);
     auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
       const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
@@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor {
         }
       }
     }
-    // Make sure every buffers are mapped.
+    // Make sure every buffer is mapped.
     ICHECK_EQ(buffer_idx, buffers.size());
   }
 
@@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor {
     MapArgsToBuffer(arg_list, buffer_list);
   }
 
-  static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func, 
size_t output_size) {
+  static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& 
inplace_indices,
+                                                int num_inputs) {
+    Array<Integer> ret;
+    int last_idx = num_inputs;
+    for (auto idx : inplace_indices) {
+      int i = idx.IntValue();
+      if (i >= 0) {
+        ret.push_back(Integer(i));
+      } else {
+        ret.push_back(Integer(last_idx));
+        last_idx++;
+      }
+    }
+
+    return ret;
+  }
+
+  static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
+                                                 const Array<Integer>& 
output_indices) {
     size_t n = func->params.size();
     int symbolic_var_index = -1;
+    size_t output_size = output_indices.size();
     ICHECK_GE(n, output_size);
-    for (size_t i = 0; i < n; ++i) {
-      const tir::Var& param = func->params[i];
+
+    Array<tir::Var> ret;
+    for (auto idx : output_indices) {
+      int i = idx.IntValue();
+      const tir::Var& param = func->params[static_cast<size_t>(i)];
       if (param->dtype.is_int() || param->dtype.is_uint()) {
         if (symbolic_var_index == -1) symbolic_var_index = i;
       } else if (param->dtype.is_handle()) {
         CHECK(symbolic_var_index == -1) << "The scalar input should be at the 
ending of the "
                                            "parameter list.";
+        ret.push_back(param);
       } else {
         LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle 
or scalar, but got: "
                    << param->dtype;
       }
     }
+
     size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index;
     ICHECK_GE(end_index, output_size);
-    size_t begin_index = end_index - output_size;
-    Array<tir::Var> output_params{func->params.begin() + begin_index,
-                                  func->params.begin() + end_index};
-    return output_params;
+    return ret;
   }
 
   /*!
@@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor {
    * \param func The old TIR PrimFunc
    * \param output_shapes The shape of output params.
    */
-  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
+  void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& 
func,
                                   const Array<Array<PrimExpr>>& output_shapes) 
{
+    bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace"));
+
     size_t n = func->params.size();
+    int num_inputs = Downcast<Tuple>(call->args[1])->fields.size();
     size_t output_size = output_shapes.size();
     ICHECK_GE(n, output_size);
-    // Allocate intermediate buffer
-    Array<tir::Buffer> alloc_buffers;
-    Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_size);
+    Array<tir::Buffer> output_buffers;
+    Array<Integer> output_idxs;
+    if (is_inplace) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices, 
num_inputs));
+    } else {
+      for (size_t i = 0; i < output_size; i++) {
+        output_idxs.push_back(num_inputs + i);
+      }
+    }
+
+    Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_idxs);
+    auto input_buffers = func_info_.expr2buffers.Get(call->args[1]);
     for (size_t i = 0; i < output_size; ++i) {
       const tir::Var& param = output_params[i];
       const tir::Buffer& buffer = func->buffer_map.at(param);
 
+      // if this is an inplace output, do not do an intermediate allocation
+      if (output_idxs[i].IntValue() < num_inputs) {
+        CHECK(input_buffers.defined()) << "Inplace functions must have some 
defined input";
+        
output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]);
+        continue;
+      }
+
       auto unify_name_hints = [this, &buffer]() {
         String base_name = buffer->name;
         String unique_name = base_name + "_intermediate";
@@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor {
       n->name = unify_name_hints();
       tir::Buffer new_buffer(n);
       func_info_.alloc_buffers.push_back(new_buffer);
-      alloc_buffers.push_back(new_buffer);
+      output_buffers.push_back(new_buffer);
 
       // Match the shape of the output buffer with the shape
       func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
       func_info_.buffer_subst_map.Set(buffer, new_buffer);
     }
     // Update expr2buffers
-    func_info_.expr2buffers.Set(expr, alloc_buffers);
+    func_info_.expr2buffers.Set(GetRef<Expr>(call), output_buffers);
   }
 
   /*!
@@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor {
   FuseFuncInfo func_info_;
   /*! \brief The tir function after fusion*/
   tir::PrimFunc fused_tir_;
+  /*! \brief Indices of inputs that are used for in-place computation */
+  std::unordered_set<size_t> inplace_indices_;
 };
 
 std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const 
Var& tuple_var) {
@@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator {
     for (const auto& [gv, func] : mod->functions) {
       // Only fuse primitive relax functions
       if (func->IsInstance<relax::FunctionNode>() && 
func->HasNonzeroAttr(attr::kPrimitive)) {
-        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
-        mutator.fused_tir_funcs_.Set(gv, fused_tir);
+        const auto& [prim_func, indices] = 
FusedTIRConstructor::GetFusedTIR(mod, gv);
+        mutator.fused_tir_funcs_.Set(gv, prim_func);
+        if (!indices.empty()) {
+          mutator.inplace_indices_.Set(gv, indices);
+        }
       }
     }
 
@@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator {
 
   Expr VisitExpr_(const CallNode* op) final {
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
 
     Call call = 
Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));
 
@@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator {
             CHECK(prim_value->value.defined())
                 << "FuseTIR requires all R.Prim arguments to have a known 
value.";
             PrimExpr expr = prim_value->value.value();
-            CHECK(expr->IsInstance<tir::VarNode>())
-                << "FuseTIR currently requires all R.Prim arguments to provide 
a single tir::Var.";
+            CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently 
requires all R.Prim "
+                                                       "arguments to provide a 
single tir::Var.";
             tir_vars.push_back(expr);
 
           } else {
             arg_list.push_back(arg);
           }
         }
-        // Step b. Create call_tir
+        // Step b. Create call_tir or call_tir_inplace
         Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
         if (!tir_vars.empty()) {
           call_args.push_back(ShapeExpr(tir_vars));
         }
-        return Call(call_tir_op_, call_args, call->attrs, 
{GetStructInfo(call)});
+        Op call_op = call_tir_op_;
+        Attrs call_attrs = call->attrs;
+        if (auto it = inplace_indices_.find(old_gv); it != 
inplace_indices_.end()) {
+          call_op = call_tir_inplace_op_;
+          auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
+          inplace_attrs->inplace_indices = (*it).second;
+          call_attrs = Attrs(inplace_attrs);
+        }
+        return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
       } else {
         // Case 1.2. The callee function is not primitive, nothing to do.
         return call;
       }
-    } else if (call->op == call_tir_op_) {
-      // Case 2. It is a call_tir, re-emit the PrimFunc.
+    } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
+      // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
       if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
         tir::PrimFunc func = 
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
         GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
@@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator {
   const IRModule& mod_;
   /*! \brief The map from global var of primitive relax function to generated 
prim func. */
   Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
+  /*! \brief The map from global var of primitive relax function to in-place 
indices
+   *  (if there are any). */
+  Map<GlobalVar, Array<Integer>> inplace_indices_;
 };
 
 IRModule FuseTIR(IRModule mod) {
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 1a4a630e3e..3cd608d8ee 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1501,5 +1501,146 @@ def test_partially_used_tuple_param():
     _check(Module, Expected)
 
 
+def test_call_tir_inplace():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(Out[v_ax0, v_ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(Out[v_ax0, v_ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_inplace_squeeze_inplace(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Expected
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor(
+                    (10, 20), dtype="float32"
+                ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_tir.py 
b/tests/python/relax/test_transform_fuse_tir.py
index 143670c701..c0a6f4448b 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1930,5 +1930,329 @@ def test_gather():
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    # T.reads(A[v_ax0, v_ax1], B[()])
+                    # T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    # T.reads(A[v_i0, v_i1])
+                    # T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    # T.reads(A[v_ax0, v_ax1])
+                    # T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                # This overwrites x and is actually evil because the function 
is marked as pure
+                # but we are doing it just to test the pass. The automatic 
DataflowUseInplaceCalls
+                # transformation will not produce code like this, but it may 
make sense to do it
+                # if ownership of x is fully and truly transferred.
+                # Users should apply with caution!
+                lv = R.call_tir_inplace(
+                    cls.add_inplace,
+                    (x, p0),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = 
cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    x[v_i0, v_i1] = T.exp(x[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1]
+
+        # note that this will clobber x! Use with caution
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
+def test_fuse_inplace_and_non_inplace():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            p0: T.Buffer((), "float32"),
+            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
+def test_use_as_inplace_and_dps():
+    @I.ir_module
+    class Module:
+        # we will use it both in-place and normally (DPS)
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @R.function(private=True)
+        def fused_sums(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.add,
+                    (x, p0, lv),
+                    inplace_indices=[2],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv2 = R.call_tir_inplace(
+                    cls.add,
+                    (x, p0, lv1),
+                    inplace_indices=[2],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(lv2)
+            return lv2
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_sums(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            p0: T.Buffer((), "float32"),
+            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+                    cls.fused_sums,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to