This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff3a48e  [TIR] Fix Tensorization IR-Comparator for Annotations (#10498)
ff3a48e is described below

commit ff3a48e9c0af1ef5c9395075858755ea5e3c93f5
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 6 01:43:26 2022 +0800

    [TIR] Fix Tensorization IR-Comparator for Annotations (#10498)
    
    This PR fixes how the tensorization IR-comparator compares annotations.
    
    Prior to this PR, the comparator required the annotation values from LHS
    and RHS to be the same object (`same_as` checks reference identity), which
    in practice never holds for two separately constructed IRs. This PR removes
    that identity check, leaving only the structural comparison via `VisitExpr`,
    and adds a regression unit test.
    
    ```c++
    bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
                                                const std::pair<String, ObjectRef>& rhs) {
      if (lhs.first != rhs.first) return false;
      if (!lhs.second.same_as(rhs.second)) return false;  // <== The values would never be the same.
                                                          //     Thus this line should be removed.
      return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second));
    }
    ```
---
 src/tir/schedule/ir_comparator.cc                  |   1 -
 .../python/unittest/test_tir_schedule_tensorize.py | 104 +++++++++++++++++++++
 2 files changed, 104 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 3e61e95..cdd17d2 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -216,7 +216,6 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {
 bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
                                             const std::pair<String, ObjectRef>& rhs) {
   if (lhs.first != rhs.first) return false;
-  if (!lhs.second.same_as(rhs.second)) return false;
   return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second));
 }
 
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 401a39f..5cef8d6 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -365,10 +365,99 @@ def tensorized_batch_matmul_outer_product(
             )
 
 
+@T.prim_func
+def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+        T.writes(C[0 : 16, 0 : 16])
+        for i, j, k in T.grid(16, 16, 16):
+            with T.block("update"):
+                T.block_attr({"test_annotation": True})
+                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+
+@T.prim_func
+def annotated_matmul(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(128, 128), "float32"],
+) -> None:
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("update"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.block_attr({"test_annotation": True})
+            with T.init():
+                C[vi, vj] = T.float32(0)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+
+    for i_outer, j_outer in T.grid(8, 8):
+        for i_inner_init, j_inner_init in T.grid(16, 16):
+            with T.block("init"):
+                vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
+                vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
+                T.block_attr({"test_annotation": True})
+                C[vi_init, vj_init] = T.float32(0)
+        for k_outer in T.grid(8):
+            with T.block("update"):
+                vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
+                T.reads(
+                    [
+                        C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+                        A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+                        B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+                    ]
+                )
+                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                A_elem_offset = T.var("int32")
+                B_elem_offset = T.var("int32")
+                C_elem_offset = T.var("int32")
+                A_sub = T.match_buffer(
+                    A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+                    [16, 16],
+                    elem_offset=A_elem_offset,
+                )
+                B_sub = T.match_buffer(
+                    B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+                    [16, 16],
+                    elem_offset=B_elem_offset,
+                )
+                C_sub = T.match_buffer(
+                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+                    [16, 16],
+                    elem_offset=C_elem_offset,
+                )
+                T.evaluate(
+                    T.tvm_mma_sync(
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        A_sub.data,
+                        T.floordiv(A_sub.elem_offset, 256),
+                        B_sub.data,
+                        T.floordiv(B_sub.elem_offset, 256),
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        dtype="handle",
+                    )
+                )
+
+
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
 
 tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+tir.TensorIntrin.register("test_annotated_mma_intrin", annotated_mma_desc, mma_intrin)
 tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin)
 tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin)
 
@@ -427,5 +516,20 @@ def test_tensorize_outer_product():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
+def test_tensorize_with_annotation():
+    func = annotated_matmul
+    s = tir.Schedule(func, debug_mask="all")
+    update = s.get_block("update")
+    i, j, k = s.get_loops(update)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    ko, ki = s.split(k, factors=[None, 16])
+    s.reorder(io, jo, ko, ii, ji, ki)
+    s.decompose_reduction(update, ko)
+    s.tensorize(ii, "test_annotated_mma_intrin")
+    tvm.ir.assert_structural_equal(annotated_tensorized_matmul, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to