This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2cae905a72 [TIR] Support pattern matching argmax/argmin generated by TOPI (#12827)
2cae905a72 is described below

commit 2cae905a727930eaaeb59085393eef1e1421fc20
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Sep 16 21:11:31 2022 -0400

    [TIR] Support pattern matching argmax/argmin generated by TOPI (#12827)
    
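    This PR introduces two reducers to the TIR reduction module so that rfactor
    and cross-thread reduction can be applied to functions that contain the
    argmax/argmin computations generated by TOPI. The new argmax/argmin
    combiners keep the smaller index when two values are equal, which is the
    pattern TOPI emits.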
---
 src/tir/schedule/primitive/reduction.cc            | 134 +++++++++++-------
 tests/python/unittest/test_tir_schedule_rfactor.py | 156 ++++++++++++++++++++-
 2 files changed, 233 insertions(+), 57 deletions(-)

diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 2dc47fa15b..dd2bcf727c 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -297,60 +297,86 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
  */
 struct ReducerRegistry {
   ReducerRegistry()
-      : reducer_getters{CreateReducerGetter(
-                            /*n_buffers=*/1,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              return Array<PrimExpr>{x[0] + y[0]};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
-                            }),
-                        CreateReducerGetter(
-                            /*n_buffers=*/1,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              return Array<PrimExpr>{x[0] * y[0]};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
-                            }),
-                        CreateReducerGetter(
-                            /*n_buffers=*/1,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              return Array<PrimExpr>{min(x[0], y[0])};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{max_value(values[0]->dtype)};
-                            }),
-                        CreateReducerGetter(
-                            /*n_buffers=*/1,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              return Array<PrimExpr>{max(x[0], y[0])};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{min_value(values[0]->dtype)};
-                            }),
-                        CreateReducerGetter(
-                            /*n_buffers=*/2,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
-                              PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
-                              return Array<PrimExpr>{idx, val};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{make_const(values[0]->dtype, -1),
-                                                     min_value(values[1]->dtype)};
-                            }),
-                        CreateReducerGetter(
-                            /*n_buffers=*/2,
-                            [](const Array<Var>& x, const Array<Var>& y) {
-                              PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
-                              PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
-                              return Array<PrimExpr>{idx, val};
-                            },
-                            [](const Array<PrimExpr>& values) {
-                              return Array<PrimExpr>{make_const(values[0]->dtype, -1),
-                                                     max_value(values[1]->dtype)};
-                            })} {}
+      : reducer_getters{
+            CreateReducerGetter(
+                /*n_buffers=*/1,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  return Array<PrimExpr>{x[0] + y[0]};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/1,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  return Array<PrimExpr>{x[0] * y[0]};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/1,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  return Array<PrimExpr>{min(x[0], y[0])};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{max_value(values[0]->dtype)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/1,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  return Array<PrimExpr>{max(x[0], y[0])};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{min_value(values[0]->dtype)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/2,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
+                  PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
+                  return Array<PrimExpr>{idx, val};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+                                         min_value(values[1]->dtype)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/2,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  PrimExpr idx =
+                      Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))),
+                             x[0], y[0]);
+                  PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]);
+                  return Array<PrimExpr>{idx, val};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+                                         min_value(values[1]->dtype)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/2,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
+                  PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
+                  return Array<PrimExpr>{idx, val};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+                                         max_value(values[1]->dtype)};
+                }),
+            CreateReducerGetter(
+                /*n_buffers=*/2,
+                [](const Array<Var>& x, const Array<Var>& y) {
+                  PrimExpr idx = Select(
+                      Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]);
+                  PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]);
+                  return Array<PrimExpr>{idx, val};
+                },
+                [](const Array<PrimExpr>& values) {
+                  return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+                                         max_value(values[1]->dtype)};
+                })} {}
 
   static void RegisterReducer(
       int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter,
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index f6db79f3ed..964fe772d8 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -15,12 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
-
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir
+from tvm import te, tir, topi
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 
@@ -1133,6 +1131,128 @@ def argmin_split_rfactor(
             argmin_v1[i] = v_argmin_v1
 
 
[email protected]_func
+def argmax_topi_rfactor(
+    placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
+    placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
+    placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
+    placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
+    for i0, i1_0, i1_1 in T.grid(1, 4, 8):
+        with T.block("placeholder_red_temp_rf"):
+            vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
+            T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
+            T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
+            with T.init():
+                placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
+                placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648
+            v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
+                placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1]
+                or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
+                and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
+                placeholder_red_temp_v0_rf[ax0, vi1_1],
+                vi1_0 * 8 + vi1_1,
+            )
+            v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
+                placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1],
+                placeholder_red_temp_v1_rf[ax0, vi1_1],
+                placeholder[ax0, vi1_0 * 8 + vi1_1],
+            )
+            placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
+            placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
+    for i0, i1_1 in T.grid(1, 8):
+        with T.block("placeholder_red_temp"):
+            vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
+            T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
+            T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
+            with T.init():
+                placeholder_red_temp_v0[ax0] = -1
+                placeholder_red_temp_v1[ax0] = -2147483648
+            v_placeholder_red_temp_v0: T.int32 = T.Select(
+                placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1]
+                or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
+                and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
+                placeholder_red_temp_v0[ax0],
+                placeholder_red_temp_v0_rf[ax0, vi1_1],
+            )
+            v_placeholder_red_temp_v1: T.int32 = T.Select(
+                placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1],
+                placeholder_red_temp_v1[ax0],
+                placeholder_red_temp_v1_rf[ax0, vi1_1],
+            )
+            placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
+            placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
+    for i0 in T.serial(1):
+        with T.block("placeholder_red"):
+            ax0 = T.axis.spatial(1, i0)
+            T.reads(placeholder_red_temp_v0[ax0])
+            T.writes(placeholder_red[ax0])
+            placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
+
+
[email protected]_func
+def argmin_topi_rfactor(
+    placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
+    placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
+    placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
+    placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
+    for i0, i1_0, i1_1 in T.grid(1, 4, 8):
+        with T.block("placeholder_red_temp_rf"):
+            vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
+            T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
+            T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
+            with T.init():
+                placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
+                placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647
+            v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
+                placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1]
+                or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
+                and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
+                placeholder_red_temp_v0_rf[ax0, vi1_1],
+                vi1_0 * 8 + vi1_1,
+            )
+            v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
+                placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1],
+                placeholder_red_temp_v1_rf[ax0, vi1_1],
+                placeholder[ax0, vi1_0 * 8 + vi1_1],
+            )
+            placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
+            placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
+    for i0, i1_1 in T.grid(1, 8):
+        with T.block("placeholder_red_temp"):
+            vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
+            T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
+            T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
+            with T.init():
+                placeholder_red_temp_v0[ax0] = -1
+                placeholder_red_temp_v1[ax0] = 2147483647
+            v_placeholder_red_temp_v0: T.int32 = T.Select(
+                placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1]
+                or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
+                and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
+                placeholder_red_temp_v0[ax0],
+                placeholder_red_temp_v0_rf[ax0, vi1_1],
+            )
+            v_placeholder_red_temp_v1: T.int32 = T.Select(
+                placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1],
+                placeholder_red_temp_v1[ax0],
+                placeholder_red_temp_v1_rf[ax0, vi1_1],
+            )
+            placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
+            placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
+    for i0 in T.serial(1):
+        with T.block("placeholder_red"):
+            ax0 = T.axis.spatial(1, i0)
+            T.reads(placeholder_red_temp_v0[ax0])
+            T.writes(placeholder_red[ax0])
+            placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match():
         s.rfactor(ki, 1)
 
 
+def test_reduction_rfactor_topi_argmax():
+    A = te.placeholder((1, 32), dtype="int32")
+    B = topi.argmax(A, axis=1)
+    argmax_topi = te.create_prim_func([A, B])
+    s = tir.Schedule(argmax_topi, debug_mask="all")
+    argmax = s.get_block("placeholder_red_temp")
+    _, k = s.get_loops(argmax)
+    _, ki = s.split(k, [None, 8])
+    rf_block = s.rfactor(ki, 1)
+    tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
+    assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp")))
+    verify_trace_roundtrip(s, mod=argmax_topi)
+
+
+def test_reduction_rfactor_topi_argmin():
+    A = te.placeholder((1, 32), dtype="int32")
+    B = topi.argmin(A, axis=1)
+    argmin_topi = te.create_prim_func([A, B])
+    s = tir.Schedule(argmin_topi, debug_mask="all")
+    argmin = s.get_block("placeholder_red_temp")
+    _, k = s.get_loops(argmin)
+    _, ki = s.split(k, [None, 8])
+    rf_block = s.rfactor(ki, 1)
+    tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
+    assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp")))
+    verify_trace_roundtrip(s, mod=argmin_topi)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to