This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7f6da09052 [TIR] Fix Datatype in Lower TVM Builtin (#14347)
7f6da09052 is described below
commit 7f6da090528b683bb285fcc657a46e45501ca48c
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Mar 20 22:04:14 2023 -0700
[TIR] Fix Datatype in Lower TVM Builtin (#14347)
Fix data type and add minimal reproducible test.
Co-authored-by: Sunghyun Park <[email protected]>
---
src/tir/transforms/lower_tvm_builtin.cc | 17 +++++++++--------
.../unittest/test_tir_transform_lower_tvm_builtin.py | 19 ++++++++++++++++++-
2 files changed, 27 insertions(+), 9 deletions(-)
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 49023a5ad0..d8df2cc55a 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator {
}
}
}
- PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
+ PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
+ // set total_bytes to uint64 to avoid overflow
total_bytes = total_bytes * op->extents[i];
}
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
@@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator {
Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1),
builtin::isnullptr(), {op->buffer_var}),
throw_last_error),
op->body});
- Stmt alloca = LetStmt(
- op->buffer_var,
- Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
- {cast(DataType::Int(32), device_type_), cast(DataType::Int(32),
device_id_),
- cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32),
op->dtype.code()),
- IntImm(DataType::Int(32), op->dtype.bits())}),
- body);
+ Stmt alloca =
+ LetStmt(op->buffer_var,
+ Call(op->buffer_var.dtype(),
Op::Get("tir.TVMBackendAllocWorkspace"),
+ {cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_),
+ total_bytes, IntImm(DataType::Int(32), op->dtype.code()),
+ IntImm(DataType::Int(32), op->dtype.bits())}),
+ body);
PrimExpr free_op = Call(DataType::Int(32),
Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_),
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index 76d6bb82cc..d224a688d2 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -16,8 +16,8 @@
# under the License.
import tvm
from tvm import te
+from tvm.script import tir as T
import numpy as np
-from tvm import testing
@tvm.register_func("tvm.test_matmul")
@@ -172,6 +172,23 @@ def test_call_packed_return_non_i32():
tvm.testing.assert_allclose(a.numpy(), expected_value)
+def test_lower_overflow_int32():
+ @T.prim_func
+ def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32),
T.int64(25690112)), "float32")):
+ T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
+ rxplaceholder_red = T.allocate([32], "float32", "global")
+ T_subtract = T.allocate([822083584], "float32", "global")
+ rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
+ rxplaceholder_1 = T.Buffer((T.int64(822083584),),
data=rxplaceholder.data)
+ T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
+ for ax1, ax2 in T.grid(32, 25690112):
+ cse_var_1: T.int32 = ax1 * 25690112 + ax2
+ T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] -
rxplaceholder_red_1[ax1]
+
+ func = variance4
+ tvm.build(func, target="llvm") # should not crash
+
+
if __name__ == "__main__":
test_call_packed_return_non_i32()
test_lower_packed_func()