This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a90fb8e2d9 [TIR][NarrowDataType] Bufferload's index should not inherit
bits constraint of value (#17411)
a90fb8e2d9 is described below
commit a90fb8e2d93215bdae2fbd2359374ebe914bee45
Author: wrongtest <[email protected]>
AuthorDate: Wed Sep 25 10:18:59 2024 +0800
[TIR][NarrowDataType] Bufferload's index should not inherit bits constraint
of value (#17411)
When narrowing dtypes, the index of a BufferLoad should not inherit the bits constraint of the loaded value.
Co-authored-by: wrongtest <[email protected]>
---
src/tir/transforms/narrow_datatype.cc | 14 +++++++++++++-
.../tir-transform/test_tir_transform_narrow_datatype.py | 17 +++++++++++++++++
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 7b6187af64..696eae201f 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -97,6 +97,13 @@ class DataTypeVisitor final : public StmtExprVisitor {
}
}
+ void VisitExpr_(const BufferLoadNode* op) {
+ int tmp = bits_;
+ bits_ = target_bits_;
+ StmtExprVisitor::VisitExpr_(op);
+ bits_ = tmp;
+ }
+
void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
@@ -245,7 +252,12 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
const CastNode* new_op = e.as<CastNode>();
ICHECK(new_op != nullptr) << "Expected type to be CastNode"
<< ", but get " << e->GetTypeKey();
- return Cast(visitor_.vmap[op], new_op->value);
+ PrimExpr new_value = new_op->value;
+ DataType cast_type = visitor_.vmap[op];
+ if (new_value.dtype() != cast_type) {
+ new_value = Cast(cast_type, new_value);
+ }
+ return new_value;
}
return Parent::VisitExpr_(op);
}
diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
index c03dd7a529..cf85f2e371 100644
--- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
+++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
@@ -413,5 +413,22 @@ def test_avg_pool2d():
tvm.ir.assert_structural_equal(after["main"],
expected_after.with_attr("global_symbol", "main"))
+def test_narrow_i64_valued_bufferload_index_to_i32():
+ @T.prim_func
+ def before(A: T.Buffer((16,), "int64")):
+ for i in range(T.int64(15)):
+ A[i + T.int64(1)] = A[i] + T.int64(1)
+
+ @T.prim_func
+ def expect(A: T.Buffer((16,), "int64")):
+ for i in range(15):
+ A[i + 1] = A[i] + T.int64(1)
+
+ after = tvm.tir.transform.NarrowDataType(32)(
+ tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+ )["main"]
+ tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main"))
+
+
if __name__ == "__main__":
tvm.testing.main()