This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f6da09052 [TIR] Fix Datatype in Lower TVM Builtin (#14347)
7f6da09052 is described below

commit 7f6da090528b683bb285fcc657a46e45501ca48c
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Mar 20 22:04:14 2023 -0700

    [TIR] Fix Datatype in Lower TVM Builtin (#14347)
    
    Fix data type and add minimal reproducible test.
    
    Co-authored-by: Sunghyun Park <[email protected]>
---
 src/tir/transforms/lower_tvm_builtin.cc               | 17 +++++++++--------
 .../unittest/test_tir_transform_lower_tvm_builtin.py  | 19 ++++++++++++++++++-
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 49023a5ad0..d8df2cc55a 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator {
         }
       }
     }
-    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
+    PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes);
     for (size_t i = 0; i < op->extents.size(); ++i) {
+      // set total_bytes to uint64 to avoid overflow
       total_bytes = total_bytes * op->extents[i];
     }
     ICHECK(device_type_.defined()) << "Unknown device type in current IR";
@@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator {
     Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
                                     throw_last_error),
                          op->body});
-    Stmt alloca = LetStmt(
-        op->buffer_var,
-        Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
-             {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
-              cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
-              IntImm(DataType::Int(32), op->dtype.bits())}),
-        body);
+    Stmt alloca =
+        LetStmt(op->buffer_var,
+                Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
+                     {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
+                      total_bytes, IntImm(DataType::Int(32), op->dtype.code()),
+                      IntImm(DataType::Int(32), op->dtype.bits())}),
+                body);
 
     PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                             {cast(DataType::Int(32), device_type_),
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index 76d6bb82cc..d224a688d2 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -16,8 +16,8 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.script import tir as T
 import numpy as np
-from tvm import testing
 
 
 @tvm.register_func("tvm.test_matmul")
@@ -172,6 +172,23 @@ def test_call_packed_return_non_i32():
     tvm.testing.assert_allclose(a.numpy(), expected_value)
 
 
+def test_lower_overflow_int32():
+    @T.prim_func
+    def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")):
+        T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
+        rxplaceholder_red = T.allocate([32], "float32", "global")
+        T_subtract = T.allocate([822083584], "float32", "global")
+        rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
+        rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data)
+        T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
+        for ax1, ax2 in T.grid(32, 25690112):
+            cse_var_1: T.int32 = ax1 * 25690112 + ax2
+            T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1]
+
+    func = variance4
+    tvm.build(func, target="llvm")  # should not crash
+
+
 if __name__ == "__main__":
     test_call_packed_return_non_i32()
     test_lower_packed_func()

Reply via email to