This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 43f06ca42a [TIR] Avoid re-defining `var = arg_var` in ArgBinder (#14952)
43f06ca42a is described below

commit 43f06ca42ab35c4528568720710b32f5d592eaf4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun May 28 15:44:32 2023 -0500

    [TIR] Avoid re-defining `var = arg_var` in ArgBinder (#14952)
    
    Prior to this commit, `ArgBinder` would always introduce a new
    variable to represent the input argument, even if the argument was
    already a primitive type.  This introduces trivial let bindings that are
    expected to be simplified out, but which can produce dangling
    `tir::Var` usage in some cases (see
    https://github.com/apache/tvm/pull/14951).
    
    This commit updates `ArgBinder` to prefer using the original
    `tir::Var` when possible.  That is, when a function takes `n: T.int32`
    as input, the packed function should produce a binding `n: T.int32 =
    T.tvm_struct_get(...)`, rather than producing a binding `arg_n =
    T.tvm_struct_get(...)` followed by `n = arg_n`.
---
 src/tir/transforms/make_packed_api.cc              | 32 ++++++----------------
 .../unittest/test_tir_transform_make_packed_api.py | 17 +++++-------
 2 files changed, 16 insertions(+), 33 deletions(-)

diff --git a/src/tir/transforms/make_packed_api.cc 
b/src/tir/transforms/make_packed_api.cc
index 825a8da45b..062f1c0509 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -262,24 +262,13 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     return res;
   };
 
-  // Need to re-declare vars, in case some arguments also appears in the 
buffer.
-  std::vector<std::pair<Var, Var>> var_def;
+  // Need to delay binding of the buffers, in case some arguments also
+  // appear in the buffer.
+  std::vector<std::pair<PrimExpr, Var>> var_def;
   std::vector<std::pair<Var, Buffer>> buffer_def;
 
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
     Var param = func_ptr->params[i];
-    std::string param_name = [&]() {
-      std::ostringstream oss;
-      oss << "arg";
-      if (param->name_hint.defined() && (!param->name_hint.empty())) {
-        oss << "." << param->name_hint;
-
-      } else {
-        oss << i;
-      }
-      return oss.str();
-    }();
-    Var v_arg = Var(param_name, param->dtype);
 
     // Pluck the device API context out based on name
     if (param->name_hint == kDeviceContextVar) {
@@ -288,19 +277,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       continue;
     }
 
+    var_def.emplace_back(f_arg_value(param.dtype(), i), param);
     if (func_ptr->buffer_map.count(param)) {
-      buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
-    } else {
-      var_def.emplace_back(v_arg, param);
+      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
     }
 
-    // Value loads
-    seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
     // type code checks
-    Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
+    Var tcode(param->name_hint + ".code", DataType::Int(32));
     seq_init.emplace_back(
         LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, 
{IntImm(DataType::Int(32), i)}), nop));
-    DataType t = v_arg.dtype();
+    DataType t = param.dtype();
     if (t.is_handle()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be pointer";
@@ -330,8 +316,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   // either 0 or the original stride will be correctly used. Checks here have
   // to use the args that may have no let binding yet. Therefore, hoisting let
   // binding for args before buffer declaration is needed.
-  for (const auto& kv : var_def) {
-    binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint, 
true);
+  for (const auto& [expr, param] : var_def) {
+    binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
   }
 
   for (const auto& kv : buffer_def) {
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py 
b/tests/python/unittest/test_tir_transform_make_packed_api.py
index cd27c0305c..8af7efb596 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -101,18 +101,15 @@ def test_variable_passed_from_args():
     assert func.body.condition.b == 2
 
     # Arguments unpacking
-    assignment = _find_assignment(func.body, "arg.input_buffer")
+    assignment = _find_assignment(func.body, "input_buffer")
     assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
 
-    assignment = _find_assignment(func.body, "arg.not_device_context")
-    assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
-
-    assignment = _find_assignment(func.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, 
"handle")'
+    assignment = _find_assignment(assignment.body, "input_buffer")
+    assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, 
"handle")'
     unpacked_input_buffer = assignment.var
 
     assignment = _find_assignment(func.body, "not_device_context")
-    assert str(assignment.value) == "arg_not_device_context"
+    assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
     unpacked_not_device_context = assignment.var
 
     seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
@@ -147,11 +144,11 @@ def test_device_api_context_implicit_resource_handle():
     assert func.body.condition.b == 1
 
     # Arguments unpacking
-    assignment = _find_assignment(func.body, "arg.input_buffer")
+    assignment = _find_assignment(func.body, "input_buffer")
     assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
 
-    assignment = _find_assignment(func.body, "input_buffer")
-    assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, 
"handle")'
+    assignment = _find_assignment(assignment.body, "input_buffer")
+    assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, 
"handle")'
     unpacked_input_buffer = assignment.var
 
     seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)

Reply via email to