This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f3fde8121b [TIR] [Schedule] Fix decompose_padding bug with dtypes (#15050)
f3fde8121b is described below
commit f3fde8121bd2fd25066ea9ec1191880b67dee268
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Thu Jun 8 12:36:20 2023 +0530
[TIR] [Schedule] Fix decompose_padding bug with dtypes (#15050)
Ran into a dtype mismatch error when the PrimFunc uses int64 dimensions:
for an index with no iteration variables, the extent of the introduced
range was a plain int32 constant (1) instead of matching the index's dtype.
---
src/tir/schedule/primitive/decompose_padding.cc | 2 +-
.../test_tir_schedule_decompose_padding.py | 41 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index 1743a34088..50b978f012 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -168,7 +168,7 @@ class PaddingInfoAnalyzer {
}
for (const arith::IterSumExpr& sum : res->indices) {
if (sum->args.empty()) {
- region.push_back(Range::FromMinExtent(sum->base, 1));
+ region.push_back(Range::FromMinExtent(sum->base, IntImm(sum->base.dtype(), /* value */ 1)));
} else {
ICHECK_EQ(sum->args.size(), 1U);
if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) {
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
index e33cfdbd34..15ed194328 100644
--- a/tests/python/unittest/test_tir_schedule_decompose_padding.py
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -41,6 +41,47 @@ def check_decompose_padding(origin, scheduled, expected, check_run=False):
tvm.testing.assert_allclose(y0.numpy(), y1.numpy())
+def test_int64_indices_batch_decompose_padding():
+ @T.prim_func
+ def before_decompose(
+ x: T.Buffer((T.int64(1), T.int64(128), T.int64(128)), "int32"),
+ y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"),
+ ):
+ for b, i, j in T.grid(T.int64(1), T.int64(140), T.int64(128)):
+ with T.block("block"):
+ vb, vi, vj = T.axis.remap("SSS", [b, i, j])
+ y[vb, vi, vj] = T.if_then_else(vi < T.int64(128), x[vb, vi, vj], 0)
+
+ @T.prim_func
+ def after_decompose(
+ x: T.Buffer((T.int64(1), T.int64(128), T.int64(128)), "int32"),
+ y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"),
+ ):
+ # with T.block("root"):
+ for b, i in T.grid(T.int64(1), T.int64(140)):
+ for j in range(T.int64(128)):
+ with T.block("block_pad_const"):
+ vb = T.axis.spatial(T.int64(1), T.int64(0))
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads()
+ T.writes(y[vb, vi, vj])
+ y[vb, vi, vj] = 0
+ for j in range(T.int64(128)):
+ with T.block("block"):
+ vb = T.axis.spatial(T.int64(1), T.int64(0))
+ vi = T.axis.spatial(T.int64(128), i)
+ vj = T.axis.spatial(T.int64(128), j)
+ T.where(i < T.int64(128))
+ T.reads(x[vb, vi, vj])
+ T.writes(y[vb, vi, vj])
+ y[vb, vi, vj] = x[vb, vi, vj]
+
+ sch = tir.Schedule(before_decompose, debug_mask="all")
+ block = sch.get_block("block")
+ sch.decompose_padding(block, sch.get_loops(block)[2])
+ check_decompose_padding(before_decompose, sch.mod["main"], after_decompose, check_run=False)
+
+
def test_1d_decompose_padding():
@T.prim_func
def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")):