This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 68f9509b0c [TIR] Fix int64 dtype mismatch in Reindex (#12934)
68f9509b0c is described below
commit 68f9509b0cece96b57581c3c21a145581b5a0365
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Sep 28 19:11:17 2022 -0700
[TIR] Fix int64 dtype mismatch in Reindex (#12934)
---
src/tir/schedule/primitive/cache_read_write.cc | 16 ++++---
tests/python/unittest/test_tir_schedule_reindex.py | 51 ++++++++++++++++++++++
2 files changed, 60 insertions(+), 7 deletions(-)
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index c76e6abaeb..e9583adbba 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -196,14 +196,16 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
// Step 1: Create block iters, access regions of the reindex block, and accessing indices to the
// reindex buffer.
for (const IterVar& iter : block->iter_vars) {
- Var var("v" + std::to_string(new_block_iters.size()));
+ Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype);
bool used = covered.count(iter->var);
- new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1),
- /*var=*/var,
- /*IterVarType=*/kDataPar));
+ new_block_iters.push_back(
+     IterVar(/*dom=*/used ? iter->dom
+                          : Range::FromMinExtent(IntImm(var->dtype, 0), IntImm(var->dtype, 1)),
+             /*var=*/var,
+             /*IterVarType=*/kDataPar));
if (used) {
reindex_indices.push_back(var);
- reindex_region.push_back(Range::FromMinExtent(var, 1));
+ reindex_region.push_back(Range::FromMinExtent(var, IntImm(var->dtype, 1)));
}
block_var_replace_map[iter->var] = var;
}
@@ -254,7 +256,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
std::vector<Var> loop_vars; // loop variables
std::vector<PrimExpr> iter_values; // bindings in block realize
for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
- Var loop_var("ax" + std::to_string(loop_vars.size()));
+ Var loop_var("ax" + std::to_string(loop_vars.size()), block->iter_vars[i]->var->dtype);
loop_vars.push_back(loop_var);
iter_values.push_back(loop_var);
}
@@ -920,7 +922,7 @@ class ReIndexRewriter : public StmtExprMutator {
for (const IterVar& iter : block->iter_vars) {
if (covered_.count(iter->var)) {
indices_.push_back(iter->var);
- region_.push_back(Range::FromMinExtent(iter->var, 1));
+ region_.push_back(Range::FromMinExtent(iter->var, IntImm(iter->var->dtype, 1)));
}
}
Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py
index c6776b0c8a..47b8b5cb88 100644
--- a/tests/python/unittest/test_tir_schedule_reindex.py
+++ b/tests/python/unittest/test_tir_schedule_reindex.py
@@ -168,6 +168,48 @@ def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "f
B[vi, vj] = A[vj, vi] + A[vi, vj]
[email protected]_func
+def mixed_dtype(
+ p0: T.Buffer[(T.int64(2), 1280), "float16"],
+ p1: T.Buffer[(1280, 1280), "float16"],
+ T_matmul_NT: T.Buffer[(T.int64(2), 1280), "float16"],
+) -> None:
+ for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280):
+ with T.block("T_matmul_NT"):
+ i = T.axis.spatial(T.int64(2), i0)
+ j, k = T.axis.remap("SR", [i1, i2])
+ T.reads(p0[i, k], p1[j, k])
+ T.writes(T_matmul_NT[i, j])
+ with T.init():
+ T_matmul_NT[i, j] = T.float16(0)
+ T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k]
+
+
[email protected]_func
+def mixed_dtype_reindex_write(
+ p0: T.Buffer[(T.int64(2), 1280), "float16"],
+ p1: T.Buffer[(1280, 1280), "float16"],
+ T_matmul_NT: T.Buffer[(T.int64(2), 1280), "float16"],
+) -> None:
+ T_matmul_NT_reindex = T.alloc_buffer([T.int64(2), 1280], dtype="float16")
+ for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280):
+ with T.block("T_matmul_NT"):
+ i = T.axis.spatial(T.int64(2), i0)
+ j, k = T.axis.remap("SR", [i1, i2])
+ T.reads(p0[i, k], p1[j, k])
+ T.writes(T_matmul_NT_reindex[i, j])
+ with T.init():
+ T_matmul_NT_reindex[i, j] = T.float16(0)
+ T_matmul_NT_reindex[i, j] = T_matmul_NT_reindex[i, j] + p0[i, k] * p1[j, k]
+ for ax0, ax1, ax2 in T.grid(T.int64(2), 1280, 1):
+ with T.block("T_matmul_NT_reindex"):
+ v0 = T.axis.spatial(T.int64(2), ax0)
+ v1, v2 = T.axis.remap("SS", [ax1, ax2])
+ T.reads(T_matmul_NT_reindex[v0, v1])
+ T.writes(T_matmul_NT[v0, v1])
+ T_matmul_NT[v0, v1] = T_matmul_NT_reindex[v0, v1]
+
+
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True})
@@ -207,5 +249,14 @@ def test_reindex_fail_multiple_read(use_block_name, use_buffer_name):
sch.reindex(block, buf)
+def test_reindex_mixed_dtype(use_block_name, use_buffer_name):
+ sch = tir.Schedule(mixed_dtype)
+ block = "T_matmul_NT" if use_block_name else sch.get_block("T_matmul_NT")
+ buf = "T_matmul_NT" if use_buffer_name else ("write", 0)
+ sch.reindex(block, buf)
+ tvm.ir.assert_structural_equal(mixed_dtype_reindex_write, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=mixed_dtype)
+
+
if __name__ == "__main__":
tvm.testing.main()