This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b2c58ef122 [TIR] Fix Data Type Mismatch (int64 vs int32) in
T.match_buffer when Working with Scalar Buffers in TIR (#18466)
b2c58ef122 is described below
commit b2c58ef122e48845437cfb29ccccaba945b92343
Author: Neo Chien <[email protected]>
AuthorDate: Thu Nov 20 13:22:17 2025 +0800
[TIR] Fix Data Type Mismatch (int64 vs int32) in T.match_buffer when
Working with Scalar Buffers in TIR (#18466)
This PR fixes issue
https://github.com/apache/tvm/issues/17392.
Using `T.match_buffer` on scalar elements caused an int64 vs. int32
data type mismatch error in TVM.
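Fix:
- Integer type coercion: when the argument and the bound value are both
integer types with the same number of lanes, the value is now cast to the
argument's data type instead of failing the data type check. Note that this
cast can narrow the value (e.g., int64 to int32).
- Type checking preserved: all other mismatches are still rejected
(e.g., int vs. float, or different lane counts).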
---------
Co-authored-by: cchung100m <[email protected]>
---
src/tir/transforms/lower_match_buffer.cc | 11 ++++++--
.../test_tir_transform_lower_match_buffer.py | 31 ++++++++++++++++++++++
2 files changed, 40 insertions(+), 2 deletions(-)
diff --git a/src/tir/transforms/lower_match_buffer.cc
b/src/tir/transforms/lower_match_buffer.cc
index f7155b09f4..dc3cc0dbab 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -220,8 +220,15 @@ class MatchBufferLower : public StmtExprMutator {
}
void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name =
"argument") {
- CHECK_EQ(arg.dtype(), value.dtype())
- << "The data type mismatched: " << arg->dtype << " vs. " <<
value->dtype;
+ if (arg.dtype() != value.dtype()) {
+ if (arg.dtype().is_int() && value.dtype().is_int() &&
+ arg.dtype().lanes() == value.dtype().lanes()) {
+ value = cast(arg.dtype(), value);
+ } else {
+ CHECK_EQ(arg.dtype(), value.dtype())
+ << "The data type mismatched: " << arg->dtype << " vs. " <<
value->dtype;
+ }
+ }
// Handle recursive case
value = Substitute(std::move(value), var_map_);
if (arg->IsInstance<VarNode>()) {
diff --git
a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
index 410269ffae..2ba658b738 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
@@ -532,5 +532,36 @@ def test_fail_match_func_param():
_check_fail(fail_match_func_param)
[email protected]_func
+def scalar_match_buffer_type_coercion(a: T.handle) -> None:
+ A = T.match_buffer(a, (8, 8))
+ for i, j in T.grid(8, 8):
+ with T.block(""):
+ vi = T.axis.spatial(8, i)
+ vj = T.axis.spatial(8, j)
+ T.reads()
+ T.writes(A[vi, vj])
+ # Create scalar match buffer from single element - this triggers
type coercion
+ scalar_buf = T.match_buffer(A[vi, vj], (), offset_factor=1)
+ scalar_buf[()] = T.float32(1.0)
+
+
[email protected]_func
+def transformed_scalar_match_buffer_type_coercion(a: T.handle) -> None:
+ A = T.match_buffer(a, (8, 8))
+ for i, j in T.grid(8, 8):
+ with T.block(""):
+ vi = T.axis.spatial(8, i)
+ vj = T.axis.spatial(8, j)
+ T.reads()
+ T.writes(A[vi, vj])
+ # Scalar match_buffer eliminated, direct assignment
+ A[vi, vj] = T.float32(1.0)
+
+
+def test_scalar_match_buffer_type_coercion():
+ _check(scalar_match_buffer_type_coercion,
transformed_scalar_match_buffer_type_coercion)
+
+
if __name__ == "__main__":
tvm.testing.main()