This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 43f06ca42a [TIR] Avoid re-defining `var = arg_var` in ArgBinder
(#14952)
43f06ca42a is described below
commit 43f06ca42ab35c4528568720710b32f5d592eaf4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun May 28 15:44:32 2023 -0500
[TIR] Avoid re-defining `var = arg_var` in ArgBinder (#14952)
Prior to this commit, `ArgBinder` would always introduce a new
variable to represent the input argument, even if the argument is already
a primitive type. This introduces trivial let bindings that are
expected to be simplified out, but which can produce dangling
`tir::Var` usage in some cases (see
https://github.com/apache/tvm/pull/14951).
This commit updates `ArgBinder` to prefer using the original
`tir::Var` when possible. That is, when a function takes `n: T.int32`
as input, the packed function should produce a binding `n: T.int32 =
T.tvm_struct_get(...)`, rather than producing a binding `arg_n =
T.tvm_struct_get(...)` followed by `n = arg_n`.
---
src/tir/transforms/make_packed_api.cc | 32 ++++++----------------
.../unittest/test_tir_transform_make_packed_api.py | 17 +++++-------
2 files changed, 16 insertions(+), 33 deletions(-)
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index 825a8da45b..062f1c0509 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -262,24 +262,13 @@ PrimFunc MakePackedAPI(PrimFunc func) {
return res;
};
- // Need to re-declare vars, in case some arguments also appears in the
buffer.
- std::vector<std::pair<Var, Var>> var_def;
+ // Need to delay binding of the buffers, in case some arguments also
+ // appear in the buffer.
+ std::vector<std::pair<PrimExpr, Var>> var_def;
std::vector<std::pair<Var, Buffer>> buffer_def;
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
- std::string param_name = [&]() {
- std::ostringstream oss;
- oss << "arg";
- if (param->name_hint.defined() && (!param->name_hint.empty())) {
- oss << "." << param->name_hint;
-
- } else {
- oss << i;
- }
- return oss.str();
- }();
- Var v_arg = Var(param_name, param->dtype);
// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
@@ -288,19 +277,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
continue;
}
+ var_def.emplace_back(f_arg_value(param.dtype(), i), param);
if (func_ptr->buffer_map.count(param)) {
- buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
- } else {
- var_def.emplace_back(v_arg, param);
+ buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
}
- // Value loads
- seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
- Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
+ Var tcode(param->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(
LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids,
{IntImm(DataType::Int(32), i)}), nop));
- DataType t = v_arg.dtype();
+ DataType t = param.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
@@ -330,8 +316,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
- for (const auto& kv : var_def) {
- binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint,
true);
+ for (const auto& [expr, param] : var_def) {
+ binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}
for (const auto& kv : buffer_def) {
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py
b/tests/python/unittest/test_tir_transform_make_packed_api.py
index cd27c0305c..8af7efb596 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -101,18 +101,15 @@ def test_variable_passed_from_args():
assert func.body.condition.b == 2
# Arguments unpacking
- assignment = _find_assignment(func.body, "arg.input_buffer")
+ assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
- assignment = _find_assignment(func.body, "arg.not_device_context")
- assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
-
- assignment = _find_assignment(func.body, "input_buffer")
- assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1,
"handle")'
+ assignment = _find_assignment(assignment.body, "input_buffer")
+ assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1,
"handle")'
unpacked_input_buffer = assignment.var
assignment = _find_assignment(func.body, "not_device_context")
- assert str(assignment.value) == "arg_not_device_context"
+ assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
unpacked_not_device_context = assignment.var
seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
@@ -147,11 +144,11 @@ def test_device_api_context_implicit_resource_handle():
assert func.body.condition.b == 1
# Arguments unpacking
- assignment = _find_assignment(func.body, "arg.input_buffer")
+ assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
- assignment = _find_assignment(func.body, "input_buffer")
- assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1,
"handle")'
+ assignment = _find_assignment(assignment.body, "input_buffer")
+ assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1,
"handle")'
unpacked_input_buffer = assignment.var
seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)