This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2cae905a72 [TIR] Support pattern matching argmax/argmin generated by TOPI (#12827)
2cae905a72 is described below
commit 2cae905a727930eaaeb59085393eef1e1421fc20
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Sep 16 21:11:31 2022 -0400
[TIR] Support pattern matching argmax/argmin generated by TOPI (#12827)
This PR introduces two reducers to the TIR reduction part, so that rfactor and
cross-thread reduction can be applied to functions that contain
argmax/argmin computations generated by TOPI.
---
src/tir/schedule/primitive/reduction.cc | 134 +++++++++++-------
tests/python/unittest/test_tir_schedule_rfactor.py | 156 ++++++++++++++++++++-
2 files changed, 233 insertions(+), 57 deletions(-)
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 2dc47fa15b..dd2bcf727c 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -297,60 +297,86 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
*/
struct ReducerRegistry {
ReducerRegistry()
- : reducer_getters{CreateReducerGetter(
- /*n_buffers=*/1,
- [](const Array<Var>& x, const Array<Var>& y) {
- return Array<PrimExpr>{x[0] + y[0]};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{make_const(values[0]->dtype, 0)};
- }),
- CreateReducerGetter(
- /*n_buffers=*/1,
- [](const Array<Var>& x, const Array<Var>& y) {
- return Array<PrimExpr>{x[0] * y[0]};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{make_const(values[0]->dtype, 1)};
- }),
- CreateReducerGetter(
- /*n_buffers=*/1,
- [](const Array<Var>& x, const Array<Var>& y) {
- return Array<PrimExpr>{min(x[0], y[0])};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{max_value(values[0]->dtype)};
- }),
- CreateReducerGetter(
- /*n_buffers=*/1,
- [](const Array<Var>& x, const Array<Var>& y) {
- return Array<PrimExpr>{max(x[0], y[0])};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{min_value(values[0]->dtype)};
- }),
- CreateReducerGetter(
- /*n_buffers=*/2,
- [](const Array<Var>& x, const Array<Var>& y) {
- PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
- PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
- return Array<PrimExpr>{idx, val};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{make_const(values[0]->dtype, -1),
-
min_value(values[1]->dtype)};
- }),
- CreateReducerGetter(
- /*n_buffers=*/2,
- [](const Array<Var>& x, const Array<Var>& y) {
- PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
- PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
- return Array<PrimExpr>{idx, val};
- },
- [](const Array<PrimExpr>& values) {
- return
Array<PrimExpr>{make_const(values[0]->dtype, -1),
-
max_value(values[1]->dtype)};
- })} {}
+ : reducer_getters{
+ CreateReducerGetter(
+ /*n_buffers=*/1,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ return Array<PrimExpr>{x[0] + y[0]};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/1,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ return Array<PrimExpr>{x[0] * y[0]};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/1,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ return Array<PrimExpr>{min(x[0], y[0])};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{max_value(values[0]->dtype)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/1,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ return Array<PrimExpr>{max(x[0], y[0])};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{min_value(values[0]->dtype)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/2,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
+ PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
+ return Array<PrimExpr>{idx, val};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+ min_value(values[1]->dtype)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/2,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ PrimExpr idx =
+ Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]),
less(x[0], y[0]))),
+ x[0], y[0]);
+ PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]);
+ return Array<PrimExpr>{idx, val};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+ min_value(values[1]->dtype)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/2,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
+ PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
+ return Array<PrimExpr>{idx, val};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+ max_value(values[1]->dtype)};
+ }),
+ CreateReducerGetter(
+ /*n_buffers=*/2,
+ [](const Array<Var>& x, const Array<Var>& y) {
+ PrimExpr idx = Select(
+ Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0],
y[0]))), x[0], y[0]);
+ PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]);
+ return Array<PrimExpr>{idx, val};
+ },
+ [](const Array<PrimExpr>& values) {
+ return Array<PrimExpr>{make_const(values[0]->dtype, -1),
+ max_value(values[1]->dtype)};
+ })} {}
static void RegisterReducer(
int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)>
combiner_getter,
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index f6db79f3ed..964fe772d8 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -15,12 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
-
import pytest
import tvm
import tvm.testing
-from tvm import tir
+from tvm import te, tir, topi
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
@@ -1133,6 +1131,128 @@ def argmin_split_rfactor(
argmin_v1[i] = v_argmin_v1
+@T.prim_func
+def argmax_topi_rfactor(
+ placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1,
"int32"]
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
+ placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
+ placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
+ placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
+ for i0, i1_0, i1_1 in T.grid(1, 4, 8):
+ with T.block("placeholder_red_temp_rf"):
+ vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
+ T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
+ T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1])
+ with T.init():
+ placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
+ placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648
+ v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
+ placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0,
vi1_0 * 8 + vi1_1]
+ or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0,
vi1_0 * 8 + vi1_1]
+ and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
+ placeholder_red_temp_v0_rf[ax0, vi1_1],
+ vi1_0 * 8 + vi1_1,
+ )
+ v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
+ placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0,
vi1_0 * 8 + vi1_1],
+ placeholder_red_temp_v1_rf[ax0, vi1_1],
+ placeholder[ax0, vi1_0 * 8 + vi1_1],
+ )
+ placeholder_red_temp_v0_rf[ax0, vi1_1] =
v_placeholder_red_temp_v0_rf
+ placeholder_red_temp_v1_rf[ax0, vi1_1] =
v_placeholder_red_temp_v1_rf
+ for i0, i1_1 in T.grid(1, 8):
+ with T.block("placeholder_red_temp"):
+ vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
+ T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1])
+ T.writes(placeholder_red_temp_v0[ax0],
placeholder_red_temp_v1[ax0])
+ with T.init():
+ placeholder_red_temp_v0[ax0] = -1
+ placeholder_red_temp_v1[ax0] = -2147483648
+ v_placeholder_red_temp_v0: T.int32 = T.Select(
+ placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0,
vi1_1]
+ or placeholder_red_temp_v1[ax0] ==
placeholder_red_temp_v1_rf[ax0, vi1_1]
+ and placeholder_red_temp_v0[ax0] <
placeholder_red_temp_v0_rf[ax0, vi1_1],
+ placeholder_red_temp_v0[ax0],
+ placeholder_red_temp_v0_rf[ax0, vi1_1],
+ )
+ v_placeholder_red_temp_v1: T.int32 = T.Select(
+ placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0,
vi1_1],
+ placeholder_red_temp_v1[ax0],
+ placeholder_red_temp_v1_rf[ax0, vi1_1],
+ )
+ placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
+ placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
+ for i0 in T.serial(1):
+ with T.block("placeholder_red"):
+ ax0 = T.axis.spatial(1, i0)
+ T.reads(placeholder_red_temp_v0[ax0])
+ T.writes(placeholder_red[ax0])
+ placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
+
+
+@T.prim_func
+def argmin_topi_rfactor(
+ placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1,
"int32"]
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
+ placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
+ placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
+ placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
+ for i0, i1_0, i1_1 in T.grid(1, 4, 8):
+ with T.block("placeholder_red_temp_rf"):
+ vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
+ T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
+ T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1])
+ with T.init():
+ placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
+ placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647
+ v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
+ placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0,
vi1_0 * 8 + vi1_1]
+ or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0,
vi1_0 * 8 + vi1_1]
+ and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
+ placeholder_red_temp_v0_rf[ax0, vi1_1],
+ vi1_0 * 8 + vi1_1,
+ )
+ v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
+ placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0,
vi1_0 * 8 + vi1_1],
+ placeholder_red_temp_v1_rf[ax0, vi1_1],
+ placeholder[ax0, vi1_0 * 8 + vi1_1],
+ )
+ placeholder_red_temp_v0_rf[ax0, vi1_1] =
v_placeholder_red_temp_v0_rf
+ placeholder_red_temp_v1_rf[ax0, vi1_1] =
v_placeholder_red_temp_v1_rf
+ for i0, i1_1 in T.grid(1, 8):
+ with T.block("placeholder_red_temp"):
+ vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
+ T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1])
+ T.writes(placeholder_red_temp_v0[ax0],
placeholder_red_temp_v1[ax0])
+ with T.init():
+ placeholder_red_temp_v0[ax0] = -1
+ placeholder_red_temp_v1[ax0] = 2147483647
+ v_placeholder_red_temp_v0: T.int32 = T.Select(
+ placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0,
vi1_1]
+ or placeholder_red_temp_v1[ax0] ==
placeholder_red_temp_v1_rf[ax0, vi1_1]
+ and placeholder_red_temp_v0[ax0] <
placeholder_red_temp_v0_rf[ax0, vi1_1],
+ placeholder_red_temp_v0[ax0],
+ placeholder_red_temp_v0_rf[ax0, vi1_1],
+ )
+ v_placeholder_red_temp_v1: T.int32 = T.Select(
+ placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0,
vi1_1],
+ placeholder_red_temp_v1[ax0],
+ placeholder_red_temp_v1_rf[ax0, vi1_1],
+ )
+ placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
+ placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
+ for i0 in T.serial(1):
+ with T.block("placeholder_red"):
+ ax0 = T.axis.spatial(1, i0)
+ T.reads(placeholder_red_temp_v0[ax0])
+ T.writes(placeholder_red[ax0])
+ placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match():
s.rfactor(ki, 1)
+def test_reduction_rfactor_topi_argmax():
+ A = te.placeholder((1, 32), dtype="int32")
+ B = topi.argmax(A, axis=1)
+ argmax_topi = te.create_prim_func([A, B])
+ s = tir.Schedule(argmax_topi, debug_mask="all")
+ argmax = s.get_block("placeholder_red_temp")
+ _, k = s.get_loops(argmax)
+ _, ki = s.split(k, [None, 8])
+ rf_block = s.rfactor(ki, 1)
+ tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor)
+ assert
s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
+ assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp")))
+ verify_trace_roundtrip(s, mod=argmax_topi)
+
+
+def test_reduction_rfactor_topi_argmin():
+ A = te.placeholder((1, 32), dtype="int32")
+ B = topi.argmin(A, axis=1)
+ argmin_topi = te.create_prim_func([A, B])
+ s = tir.Schedule(argmin_topi, debug_mask="all")
+ argmin = s.get_block("placeholder_red_temp")
+ _, k = s.get_loops(argmin)
+ _, ki = s.split(k, [None, 8])
+ rf_block = s.rfactor(ki, 1)
+ tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor)
+ assert
s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
+ assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp")))
+ verify_trace_roundtrip(s, mod=argmin_topi)
+
+
if __name__ == "__main__":
tvm.testing.main()