This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e00253b68 [TIR] Fix Primitive Rfactor DType (#15413)
3e00253b68 is described below
commit 3e00253b68e4ae92886fa36bedf992421f6c0854
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Jul 27 07:32:23 2023 +0800
[TIR] Fix Primitive Rfactor DType (#15413)
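The rfactor primitive creates/rewrites two blocks, together with their
read/write regions. However, the generated read/write region extents were
not valid when the index is int64: each extent was built from the bare
literal 1 (int32 by default) instead of a constant with the same dtype as
the region's min. This commit fixes the issue by constructing the extent
with make_const using the index dtype.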
---
src/tir/schedule/primitive/reduction.cc | 6 ++-
tests/python/unittest/test_tir_schedule_rfactor.py | 57 ++++++++++++++++++++++
2 files changed, 61 insertions(+), 2 deletions(-)
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 6069f4289c..cade5457b0 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator {
write_regions_.reserve(old_block->writes.size());
for (const BufferRegion& write_region : old_block->writes) {
Array<Range> region = write_region->region;
- region.insert(region.begin() + factor_axis_,
Range::FromMinExtent(additional_iter_->var, 1));
+ region.insert(region.begin() + factor_axis_,
+ Range::FromMinExtent(additional_iter_->var,
+
make_const(additional_iter_->var.dtype(), 1)));
Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
ICHECK(rf_buffer.defined());
write_regions_.push_back(BufferRegion(rf_buffer.value(),
Substitute(region, var_map_)));
@@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
Array<Range> region;
region.reserve(buf_load->indices.size());
for (const PrimExpr& index : buf_load->indices) {
- region.push_back(Range::FromMinExtent(index, 1));
+ region.push_back(Range::FromMinExtent(index, make_const(index.dtype(),
1)));
}
buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
}
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index c1eb04b7c3..43374d3751 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
+
import tvm
import tvm.testing
from tvm import te, tir, topi
@@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin():
verify_trace_roundtrip(s, mod=argmin_topi)
+def test_reduction_rfactor_int64():
+ # fmt: off
+ @T.prim_func
+ def before(
+ A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
+ T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
+ ):
+ with T.block("update"):
+ vi, vj = T.axis.remap("SS", [i0, i1])
+ vk = T.axis.R(
+ T.int64(128),
+ i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) +
i2_inner_inner,
+ )
+ with T.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+
+ @T.prim_func
+ def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)),
"float32")
+
+ for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in
T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)):
+ with T.block("update_rf"):
+ vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer=
T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer])
+ with T.init():
+ C_rf[vi2_inner_inner, vi, vj] = 0.0
+ C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj]
+ (
+ A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer *
T.int64(4))) + vi2_inner_inner)]
+ * B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer *
T.int64(4))) + vi2_inner_inner)]
+ )
+
+ for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128),
T.int64(4)):
+ with T.block("update"):
+ vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS",
[i2_inner_inner_1, i0_1, i1_1])
+ with T.init():
+ C[vi_1, vj_1] = 0.0
+ C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1,
vj_1]
+ # fmt: on
+
+ s = tir.Schedule(before, debug_mask="all")
+ update = s.get_block("update")
+ _, _, _, _, kii = s.get_loops(update)
+ rf_block = s.rfactor(kii, 0)
+ assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
+ assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
+ assert s.get(update).same_as(s.get(s.get_block("update")))
+ verify_trace_roundtrip(s, mod=before)
+
+
if __name__ == "__main__":
tvm.testing.main()