This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b2c58ef122 [TIR] Fix Data Type Mismatch (int64 vs int32) in 
T.match_buffer when Working with Scalar Buffers in TIR (#18466)
b2c58ef122 is described below

commit b2c58ef122e48845437cfb29ccccaba945b92343
Author: Neo Chien <[email protected]>
AuthorDate: Thu Nov 20 13:22:17 2025 +0800

    [TIR] Fix Data Type Mismatch (int64 vs int32) in T.match_buffer when 
Working with Scalar Buffers in TIR (#18466)
    
    This PR fixes issue
    https://github.com/apache/tvm/issues/17392.
    
    The issue was that `T.match_buffer` on scalar elements caused an
    int64 vs. int32 data type mismatch error in TVM.
    
    Fix:
    - Safe Type Coercion: Allows automatic casting between integer types
    when they have the same number of lanes
    - Type Safety Preserved: Still rejects incompatible type combinations
    (int vs float, different lane counts)
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 src/tir/transforms/lower_match_buffer.cc           | 11 ++++++--
 .../test_tir_transform_lower_match_buffer.py       | 31 ++++++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/src/tir/transforms/lower_match_buffer.cc 
b/src/tir/transforms/lower_match_buffer.cc
index f7155b09f4..dc3cc0dbab 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -220,8 +220,15 @@ class MatchBufferLower : public StmtExprMutator {
   }
 
   void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = 
"argument") {
-    CHECK_EQ(arg.dtype(), value.dtype())
-        << "The data type mismatched: " << arg->dtype << " vs. " << 
value->dtype;
+    if (arg.dtype() != value.dtype()) {
+      if (arg.dtype().is_int() && value.dtype().is_int() &&
+          arg.dtype().lanes() == value.dtype().lanes()) {
+        value = cast(arg.dtype(), value);
+      } else {
+        CHECK_EQ(arg.dtype(), value.dtype())
+            << "The data type mismatched: " << arg->dtype << " vs. " << 
value->dtype;
+      }
+    }
     // Handle recursive case
     value = Substitute(std::move(value), var_map_);
     if (arg->IsInstance<VarNode>()) {
diff --git 
a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py 
b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
index 410269ffae..2ba658b738 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py
@@ -532,5 +532,36 @@ def test_fail_match_func_param():
     _check_fail(fail_match_func_param)
 
 
+@T.prim_func
+def scalar_match_buffer_type_coercion(a: T.handle) -> None:
+    A = T.match_buffer(a, (8, 8))
+    for i, j in T.grid(8, 8):
+        with T.block(""):
+            vi = T.axis.spatial(8, i)
+            vj = T.axis.spatial(8, j)
+            T.reads()
+            T.writes(A[vi, vj])
+            # Create scalar match buffer from single element - this triggers 
type coercion
+            scalar_buf = T.match_buffer(A[vi, vj], (), offset_factor=1)
+            scalar_buf[()] = T.float32(1.0)
+
+
+@T.prim_func
+def transformed_scalar_match_buffer_type_coercion(a: T.handle) -> None:
+    A = T.match_buffer(a, (8, 8))
+    for i, j in T.grid(8, 8):
+        with T.block(""):
+            vi = T.axis.spatial(8, i)
+            vj = T.axis.spatial(8, j)
+            T.reads()
+            T.writes(A[vi, vj])
+            # Scalar match_buffer eliminated, direct assignment
+            A[vi, vj] = T.float32(1.0)
+
+
+def test_scalar_match_buffer_type_coercion():
+    _check(scalar_match_buffer_type_coercion, 
transformed_scalar_match_buffer_type_coercion)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to