This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff3a48e [TIR] Fix Tensorization IR-Comparator for Annotations (#10498)
ff3a48e is described below
commit ff3a48e9c0af1ef5c9395075858755ea5e3c93f5
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 6 01:43:26 2022 +0800
[TIR] Fix Tensorization IR-Comparator for Annotations (#10498)
This PR fixes the way in which the tensorization
IR-comparator compares annotations.
Prior to this PR, the comparator required the annotation values from LHS
and RHS to be exactly the same object, which is, in fact, never possible. This PR
removes this comparison requirement (with a regression unit test).
```c++
bool TensorizeComparator::CompareAnnotation(const std::pair<String,
ObjectRef>& lhs,
const std::pair<String,
ObjectRef>& rhs) {
if (lhs.first != rhs.first) return false;
if (!lhs.second.same_as(rhs.second)) return false; // <== The values
would never be the same.
// Thus this line
should be removed.
return VisitExpr(Downcast<PrimExpr>(lhs.second),
Downcast<PrimExpr>(rhs.second));
}
```
---
src/tir/schedule/ir_comparator.cc | 1 -
.../python/unittest/test_tir_schedule_tensorize.py | 104 +++++++++++++++++++++
2 files changed, 104 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/ir_comparator.cc
b/src/tir/schedule/ir_comparator.cc
index 3e61e95..cdd17d2 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -216,7 +216,6 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const
Var& rhs) {
bool TensorizeComparator::CompareAnnotation(const std::pair<String,
ObjectRef>& lhs,
const std::pair<String,
ObjectRef>& rhs) {
if (lhs.first != rhs.first) return false;
- if (!lhs.second.same_as(rhs.second)) return false;
return VisitExpr(Downcast<PrimExpr>(lhs.second),
Downcast<PrimExpr>(rhs.second));
}
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py
b/tests/python/unittest/test_tir_schedule_tensorize.py
index 401a39f..5cef8d6 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -365,10 +365,99 @@ def tensorized_batch_matmul_outer_product(
)
+@T.prim_func
+def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+ B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+ C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+ T.writes(C[0 : 16, 0 : 16])
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("update"):
+ T.block_attr({"test_annotation": True})
+ vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+ C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+
+@T.prim_func
+def annotated_matmul(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+) -> None:
+ for i, j, k in T.grid(128, 128, 128):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.block_attr({"test_annotation": True})
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+ C = T.match_buffer(c, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ B = T.match_buffer(b, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ A = T.match_buffer(a, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+
+ for i_outer, j_outer in T.grid(8, 8):
+ for i_inner_init, j_inner_init in T.grid(16, 16):
+ with T.block("init"):
+ vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
+ vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
+ T.block_attr({"test_annotation": True})
+ C[vi_init, vj_init] = T.float32(0)
+ for k_outer in T.grid(8):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
+ T.reads(
+ [
+ C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+ A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+ B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+ ]
+ )
+ T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ A_elem_offset = T.var("int32")
+ B_elem_offset = T.var("int32")
+ C_elem_offset = T.var("int32")
+ A_sub = T.match_buffer(
+ A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+ [16, 16],
+ elem_offset=A_elem_offset,
+ )
+ B_sub = T.match_buffer(
+ B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+ [16, 16],
+ elem_offset=B_elem_offset,
+ )
+ C_sub = T.match_buffer(
+ C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+ [16, 16],
+ elem_offset=C_elem_offset,
+ )
+ T.evaluate(
+ T.tvm_mma_sync(
+ C_sub.data,
+ T.floordiv(C_sub.elem_offset, 256),
+ A_sub.data,
+ T.floordiv(A_sub.elem_offset, 256),
+ B_sub.data,
+ T.floordiv(B_sub.elem_offset, 256),
+ C_sub.data,
+ T.floordiv(C_sub.elem_offset, 256),
+ dtype="handle",
+ )
+ )
+
+
# fmt: off
# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+tir.TensorIntrin.register("test_annotated_mma_intrin", annotated_mma_desc,
mma_intrin)
tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc,
dot_product_intrin)
tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc,
outer_product_intrin)
@@ -427,5 +516,20 @@ def test_tensorize_outer_product():
verify_trace_roundtrip(sch=s, mod=func)
+def test_tensorize_with_annotation():
+ func = annotated_matmul
+ s = tir.Schedule(func, debug_mask="all")
+ update = s.get_block("update")
+ i, j, k = s.get_loops(update)
+ io, ii = s.split(i, factors=[None, 16])
+ jo, ji = s.split(j, factors=[None, 16])
+ ko, ki = s.split(k, factors=[None, 16])
+ s.reorder(io, jo, ko, ii, ji, ki)
+ s.decompose_reduction(update, ko)
+ s.tensorize(ii, "test_annotated_mma_intrin")
+ tvm.ir.assert_structural_equal(annotated_tensorized_matmul, s.mod["main"])
+ verify_trace_roundtrip(sch=s, mod=func)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))