This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e00253b68 [TIR] Fix Primitive Rfactor DType (#15413)
3e00253b68 is described below

commit 3e00253b68e4ae92886fa36bedf992421f6c0854
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Jul 27 07:32:23 2023 +0800

    [TIR] Fix Primitive Rfactor DType (#15413)
    
    The rfactor primitive creates/rewrites two blocks, together with their
    block read/write regions. However, the generated read/write region
    extents were not valid for int64 indices, because the extent did not
    use the dtype of the index. This commit fixes the issue.
---
 src/tir/schedule/primitive/reduction.cc            |  6 ++-
 tests/python/unittest/test_tir_schedule_rfactor.py | 57 ++++++++++++++++++++++
 2 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 6069f4289c..cade5457b0 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator {
     write_regions_.reserve(old_block->writes.size());
     for (const BufferRegion& write_region : old_block->writes) {
       Array<Range> region = write_region->region;
-      region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1));
+      region.insert(region.begin() + factor_axis_,
+                    Range::FromMinExtent(additional_iter_->var,
+                                         make_const(additional_iter_->var.dtype(), 1)));
       Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
       ICHECK(rf_buffer.defined());
       write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_)));
@@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
       Array<Range> region;
       region.reserve(buf_load->indices.size());
       for (const PrimExpr& index : buf_load->indices) {
-        region.push_back(Range::FromMinExtent(index, 1));
+        region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
       }
       buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
     }
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index c1eb04b7c3..43374d3751 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import te, tir, topi
@@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin():
     verify_trace_roundtrip(s, mod=argmin_topi)
 
 
+def test_reduction_rfactor_int64():
+    # fmt: off
+    @T.prim_func
+    def before(
+        A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+    ):
+        for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
+            T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
+        ):
+            with T.block("update"):
+                vi, vj = T.axis.remap("SS", [i0, i1])
+                vk = T.axis.R(
+                    T.int64(128),
+                    i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner,
+                )
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+
+    @T.prim_func
+    def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+    ):
+        C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32")
+
+        for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)):
+            with T.block("update_rf"):
+                vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer])
+                with T.init():
+                    C_rf[vi2_inner_inner, vi, vj] = 0.0
+                C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
+                    A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
+                    * B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
+                )
+
+        for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)):
+            with T.block("update"):
+                vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1])
+                with T.init():
+                    C[vi_1, vj_1] = 0.0
+                C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
+    # fmt: on
+
+    s = tir.Schedule(before, debug_mask="all")
+    update = s.get_block("update")
+    _, _, _, _, kii = s.get_loops(update)
+    rf_block = s.rfactor(kii, 0)
+    assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
+    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
+    assert s.get(update).same_as(s.get(s.get_block("update")))
+    verify_trace_roundtrip(s, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to