Lunderberg commented on code in PR #16487:
URL: https://github.com/apache/tvm/pull/16487#discussion_r1471322175


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -438,9 +440,35 @@ class FusedTIRConstructor : public ExprVisitor {
     auto it = func_info_.expr2buffers.find(body);
     ICHECK(it != func_info_.expr2buffers.end())
         << "Fail to detect output buffers for function body";
+
     const Array<tir::Buffer>& buffers = (*it).second;
+
+    // map of input buffers to indices (helpful for detecting in-place inputs)
+    std::unordered_map<tir::Buffer, Integer, ObjectPtrHash, ObjectPtrEqual> 
buffer_to_idx;
+    std::unordered_map<tir::Var, Integer, ObjectPtrHash, ObjectPtrEqual> 
input_to_idx;
+    for (size_t i = 0; i < func_info_.params.size(); i++) {
+      input_to_idx[func_info_.params[i]] = Integer(i);
+    }
+    for (auto kv : func_info_.buffer_map) {

Review Comment:
   Nitpick: Using structured bindings (`for(auto [var, buffer]: 
func_info_.buffer_map)`) would make this easier to read.



##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -1930,5 +1930,251 @@ def main(
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                # this overwrites x and is actually evil but we are doing it 
just to test the pass

Review Comment:
   I like the comment calling it out.  Since this is a private function, I 
could see it being valid (and useful!) in cases where the calling scope never 
uses `x` again.  Seems like we keep circling around linear types, and whether 
transfer of ownership should be representable when calling relax functions.



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -438,9 +440,35 @@ class FusedTIRConstructor : public ExprVisitor {
     auto it = func_info_.expr2buffers.find(body);
     ICHECK(it != func_info_.expr2buffers.end())
         << "Fail to detect output buffers for function body";
+
     const Array<tir::Buffer>& buffers = (*it).second;
+
+    // map of input buffers to indices (helpful for detecting in-place inputs)
+    std::unordered_map<tir::Buffer, Integer, ObjectPtrHash, ObjectPtrEqual> 
buffer_to_idx;
+    std::unordered_map<tir::Var, Integer, ObjectPtrHash, ObjectPtrEqual> 
input_to_idx;
+    for (size_t i = 0; i < func_info_.params.size(); i++) {
+      input_to_idx[func_info_.params[i]] = Integer(i);
+    }
+    for (auto kv : func_info_.buffer_map) {
+      if (input_to_idx.count(kv.first)) {
+        buffer_to_idx[kv.second] = input_to_idx[kv.first];
+      }
+    }
+
+    // numbered separately because the number of output *vars* might differ 
from the
+    // number of outputs if there are in-place inputs
+    int out_idx = 0;
     for (size_t i = 0; i < buffers.size(); ++i) {
-      tir::Var param = tir::Var("p_output" + std::to_string(i), 
PrimType(DataType::Handle()));
+      // Do not add output vars for in-place inputs
+      // (i.e., already listed in the buffer map. This would result
+      // in duplicates in the buffer map otherwise)
+      if (buffer_to_idx.count(buffers[i])) {

Review Comment:
   Nitpick: Using `.count` followed by `operator[]` requires two lookups into 
the map.  This should be avoided, by using `.find` instead.
   
   ```c++
   if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) {
       inplace_indices_.push_back(it->second);
   }
   ```



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -897,8 +973,11 @@ class TIRFuseMutator : public ExprMutator {
     for (const auto& [gv, func] : mod->functions) {
       // Only fuse primitive relax functions
       if (func->IsInstance<relax::FunctionNode>() && 
func->HasNonzeroAttr(attr::kPrimitive)) {
-        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
-        mutator.fused_tir_funcs_.Set(gv, fused_tir);
+        const auto& [prim_func, indices] = 
FusedTIRConstructor::GetFusedTIR(mod, gv);
+        mutator.fused_tir_funcs_.Set(gv, prim_func);
+        if (!indices.empty()) {

Review Comment:
   Nitpick: Using `if (indices.size())` instead of `if (!indices.empty())` 
would avoid the double negative and make the condition easier to read.  
(Though, personal preference as `if (indices.size())` relies on conversion of 
non-zero `size_t` to `true`.)



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -367,17 +368,18 @@ class FusedTIRConstructor : public ExprVisitor {
    * \brief Construct a fused TIR PrimFunc from a relax sub-function
    * \param mod The IRModule
    * \param gv The global var of relax subfunction to be fused into one 
PrimFunc
-   * \return The fused TIR PrimFunc
+   * \return The fused TIR PrimFunc and the in-place indices (non-empty for an 
in-place call)
    */
-  static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+  static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR(const IRModule& 
mod,
+                                                              const GlobalVar& 
gv) {
     FusedTIRConstructor visitor(mod, gv->name_hint);
     BaseFunc f = mod->Lookup(gv);
     CHECK(f->IsInstance<relax::FunctionNode>())
         << "Expected relax functions, but got: " << f->GetTypeKey();
     CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
         << "Expected a function with attr `kPrimitive`";
     visitor(Downcast<relax::Function>(f));
-    return visitor.fused_tir_;
+    return {visitor.fused_tir_, visitor.inplace_indices_};

Review Comment:
   Can we add a safety check that all values in `visitor.inplace_indices_` are 
unique?  Non-unique values would mean that the fused function returns two 
tensors, one of which is an in-place operation that overwrites the other.
   
   This should only be a safety check, as any failures would mean that an 
upstream pass had made an error by inserting `call_tir_inplace` in the first 
place.  However, since accidental aliasing would be rather tricky to hunt down 
at runtime, and since this would be an easy place to catch it, I think it would 
be worth adding a check.



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -985,26 +1065,34 @@ class TIRFuseMutator : public ExprMutator {
             CHECK(prim_value->value.defined())
                 << "FuseTIR requires all R.Prim arguments to have a known 
value.";
             PrimExpr expr = prim_value->value.value();
-            CHECK(expr->IsInstance<tir::VarNode>())
-                << "FuseTIR currently requires all R.Prim arguments to provide 
a single tir::Var.";
+            CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently 
requires all R.Prim "
+                                                       "arguments to provide a 
single tir::Var.";
             tir_vars.push_back(expr);
 
           } else {
             arg_list.push_back(arg);
           }
         }
-        // Step b. Create call_tir
+        // Step b. Create call_tir or call_tir_inplace
         Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
         if (!tir_vars.empty()) {
           call_args.push_back(ShapeExpr(tir_vars));
         }
-        return Call(call_tir_op_, call_args, call->attrs, 
{GetStructInfo(call)});
+        Op call_op = call_tir_op_;
+        Attrs call_attrs = call->attrs;
+        if (inplace_indices_.count(old_gv)) {

Review Comment:
   Same request here: avoid `.count` followed by `.at`, using a single `.find` instead.
   
   ```c++
   if(auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
       ...
   }
   ```



##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -1930,5 +1930,251 @@ def main(
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                # this overwrites x and is actually evil but we are doing it 
just to test the pass

Review Comment:
   Can we add to the comment that this use of `call_tir_inplace` is 
deliberately not inserted by `relax.transform.DataflowUseInplaceCalls`?



##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -1930,5 +1930,251 @@ def main(
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])

Review Comment:
   Nitpick: Unless they're relevant to the specific test case, removing the 
`T.reads` and `T.writes` annotations can make the test case easier to read.



##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -1930,5 +1930,251 @@ def main(
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                # this overwrites x and is actually evil but we are doing it 
just to test the pass
+                lv = R.call_tir_inplace(
+                    cls.add_inplace,
+                    (x, p0),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = 
cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: 
T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(x[v_ax0, v_ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(x[v_i0, v_i1])
+                    x[v_i0, v_i1] = T.exp(x[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1])
+                    T.writes(x[v_ax0, v_ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1]
+
+        # note that this will clobber x! Use with caution
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
+def test_fuse_inplace_and_non_inplace():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(Out[v_ax0, v_ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = 
cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            p0: T.Buffer((), "float32"),
+            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(p_output0[v_ax0, v_ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(p_output0[v_i0, v_i1])
+                    T.writes(p_output0[v_i0, v_i1])
+                    p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(p_output0[v_ax0, v_ax1])
+                    T.writes(p_output0[v_ax0, v_ax1])
+                    p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+

Review Comment:
   Can we add a test case in which `main` calls two different relax 
`kPrimitive` functions, both of which use the same underlying TIR functions, 
but only one of which uses  `T.call_tir_inplace`?
   
   I think it should already be handled correctly in the current 
implementation, but a lot of potential optimization (e.g. caching the result of 
`MapArgsToBuffer` by the argument StructInfo) would break it, and so it would 
be good to include.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to