This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a90fb8e2d9 [TIR][NarrowDataType] Bufferload's index should not inherit bits constraint of value (#17411)
a90fb8e2d9 is described below

commit a90fb8e2d93215bdae2fbd2359374ebe914bee45
Author: wrongtest <[email protected]>
AuthorDate: Wed Sep 25 10:18:59 2024 +0800

    [TIR][NarrowDataType] Bufferload's index should not inherit bits constraint of value (#17411)
    
    bufferload's index dtype narrowing should not inherit value bits constraint
    
    Co-authored-by: wrongtest <[email protected]>
---
 src/tir/transforms/narrow_datatype.cc                   | 14 +++++++++++++-
 .../tir-transform/test_tir_transform_narrow_datatype.py | 17 +++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 7b6187af64..696eae201f 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -97,6 +97,13 @@ class DataTypeVisitor final : public StmtExprVisitor {
     }
   }
 
+  void VisitExpr_(const BufferLoadNode* op) {
+    int tmp = bits_;
+    bits_ = target_bits_;
+    StmtExprVisitor::VisitExpr_(op);
+    bits_ = tmp;
+  }
+
   void VisitStmt_(const ForNode* op) {
     analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
     vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
@@ -245,7 +252,12 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
       const CastNode* new_op = e.as<CastNode>();
       ICHECK(new_op != nullptr) << "Expected type to be CastNode"
                                 << ", but get " << e->GetTypeKey();
-      return Cast(visitor_.vmap[op], new_op->value);
+      PrimExpr new_value = new_op->value;
+      DataType cast_type = visitor_.vmap[op];
+      if (new_value.dtype() != cast_type) {
+        new_value = Cast(cast_type, new_value);
+      }
+      return new_value;
     }
     return Parent::VisitExpr_(op);
   }
diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
index c03dd7a529..cf85f2e371 100644
--- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
+++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
@@ -413,5 +413,22 @@ def test_avg_pool2d():
     tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main"))
 
 
+def test_narrow_i64_valued_bufferload_index_to_i32():
+    @T.prim_func
+    def before(A: T.Buffer((16,), "int64")):
+        for i in range(T.int64(15)):
+            A[i + T.int64(1)] = A[i] + T.int64(1)
+
+    @T.prim_func
+    def expect(A: T.Buffer((16,), "int64")):
+        for i in range(15):
+            A[i + 1] = A[i] + T.int64(1)
+
+    after = tvm.tir.transform.NarrowDataType(32)(
+        tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    )["main"]
+    tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to