This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 68f9509b0c [TIR] Fix int64 dtype mismatch in Reindex (#12934)
68f9509b0c is described below

commit 68f9509b0cece96b57581c3c21a145581b5a0365
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Sep 28 19:11:17 2022 -0700

    [TIR] Fix int64 dtype mismatch in Reindex (#12934)
---
 src/tir/schedule/primitive/cache_read_write.cc     | 16 ++++---
 tests/python/unittest/test_tir_schedule_reindex.py | 51 ++++++++++++++++++++++
 2 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index c76e6abaeb..e9583adbba 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -196,14 +196,16 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
   // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the
   // reindex buffer.
   for (const IterVar& iter : block->iter_vars) {
-    Var var("v" + std::to_string(new_block_iters.size()));
+    Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype);
     bool used = covered.count(iter->var);
-    new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1),
-                                      /*var=*/var,
-                                      /*IterVarType=*/kDataPar));
+    new_block_iters.push_back(
+        IterVar(/*dom=*/used ? iter->dom
+                             : Range::FromMinExtent(IntImm(var->dtype, 0), IntImm(var->dtype, 1)),
+                /*var=*/var,
+                /*IterVarType=*/kDataPar));
     if (used) {
       reindex_indices.push_back(var);
-      reindex_region.push_back(Range::FromMinExtent(var, 1));
+      reindex_region.push_back(Range::FromMinExtent(var, IntImm(var->dtype, 1)));
     }
     block_var_replace_map[iter->var] = var;
   }
@@ -254,7 +256,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
   std::vector<Var> loop_vars;         // loop variables
   std::vector<PrimExpr> iter_values;  // bindings in block realize
   for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
-    Var loop_var("ax" + std::to_string(loop_vars.size()));
+    Var loop_var("ax" + std::to_string(loop_vars.size()), block->iter_vars[i]->var->dtype);
     loop_vars.push_back(loop_var);
     iter_values.push_back(loop_var);
   }
@@ -920,7 +922,7 @@ class ReIndexRewriter : public StmtExprMutator {
       for (const IterVar& iter : block->iter_vars) {
         if (covered_.count(iter->var)) {
           indices_.push_back(iter->var);
-          region_.push_back(Range::FromMinExtent(iter->var, 1));
+          region_.push_back(Range::FromMinExtent(iter->var, IntImm(iter->var->dtype, 1)));
         }
       }
       Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py
index c6776b0c8a..47b8b5cb88 100644
--- a/tests/python/unittest/test_tir_schedule_reindex.py
+++ b/tests/python/unittest/test_tir_schedule_reindex.py
@@ -168,6 +168,48 @@ def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "f
             B[vi, vj] = A[vj, vi] + A[vi, vj]
 
 
+@T.prim_func
+def mixed_dtype(
+    p0: T.Buffer[(T.int64(2), 1280), "float16"],
+    p1: T.Buffer[(1280, 1280), "float16"],
+    T_matmul_NT: T.Buffer[(T.int64(2), 1280), "float16"],
+) -> None:
+    for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280):
+        with T.block("T_matmul_NT"):
+            i = T.axis.spatial(T.int64(2), i0)
+            j, k = T.axis.remap("SR", [i1, i2])
+            T.reads(p0[i, k], p1[j, k])
+            T.writes(T_matmul_NT[i, j])
+            with T.init():
+                T_matmul_NT[i, j] = T.float16(0)
+            T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k]
+
+
+@T.prim_func
+def mixed_dtype_reindex_write(
+    p0: T.Buffer[(T.int64(2), 1280), "float16"],
+    p1: T.Buffer[(1280, 1280), "float16"],
+    T_matmul_NT: T.Buffer[(T.int64(2), 1280), "float16"],
+) -> None:
+    T_matmul_NT_reindex = T.alloc_buffer([T.int64(2), 1280], dtype="float16")
+    for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280):
+        with T.block("T_matmul_NT"):
+            i = T.axis.spatial(T.int64(2), i0)
+            j, k = T.axis.remap("SR", [i1, i2])
+            T.reads(p0[i, k], p1[j, k])
+            T.writes(T_matmul_NT_reindex[i, j])
+            with T.init():
+                T_matmul_NT_reindex[i, j] = T.float16(0)
+            T_matmul_NT_reindex[i, j] = T_matmul_NT_reindex[i, j] + p0[i, k] * p1[j, k]
+    for ax0, ax1, ax2 in T.grid(T.int64(2), 1280, 1):
+        with T.block("T_matmul_NT_reindex"):
+            v0 = T.axis.spatial(T.int64(2), ax0)
+            v1, v2 = T.axis.remap("SS", [ax1, ax2])
+            T.reads(T_matmul_NT_reindex[v0, v1])
+            T.writes(T_matmul_NT[v0, v1])
+            T_matmul_NT[v0, v1] = T_matmul_NT_reindex[v0, v1]
+
+
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
 use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True})
 
@@ -207,5 +249,14 @@ def test_reindex_fail_multiple_read(use_block_name, use_buffer_name):
         sch.reindex(block, buf)
 
 
+def test_reindex_mixed_dtype(use_block_name, use_buffer_name):
+    sch = tir.Schedule(mixed_dtype)
+    block = "T_matmul_NT" if use_block_name else sch.get_block("T_matmul_NT")
+    buf = "T_matmul_NT" if use_buffer_name else ("write", 0)
+    sch.reindex(block, buf)
+    tvm.ir.assert_structural_equal(mixed_dtype_reindex_write, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=mixed_dtype)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to