This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 937a14f [TIR][Analysis] Add SuggestIndexMap for layout rewriting (#10732)
937a14f is described below
commit 937a14f07f981b8bd71eda480890315123abe67c
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Mar 25 11:19:46 2022 -0700
[TIR][Analysis] Add SuggestIndexMap for layout rewriting (#10732)
This PR added an analysis function `SuggestIndexMap` to analyze buffer
access pattern and suggest index map for layout transformations.
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
---
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/function.py | 15 ++
python/tvm/tir/schedule/__init__.py | 2 +
python/tvm/tir/schedule/analysis.py | 58 ++++++
src/tir/ir/index_map.cc | 2 +
src/tir/schedule/analysis.h | 14 ++
src/tir/schedule/analysis/layout.cc | 212 +++++++++++++++++++++
.../python/unittest/test_tir_schedule_analysis.py | 107 +++++++++++
8 files changed, 411 insertions(+), 1 deletion(-)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 17f9aa3..2d201bb 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -42,7 +42,7 @@ from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-from .function import PrimFunc, TensorIntrin
+from .function import PrimFunc, TensorIntrin, IndexMap
from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 98af3b4..643bbca 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -295,3 +295,18 @@ class IndexMap(Object):
final_indices = mapping_function(*args)
return IndexMap(args, final_indices)
+
+ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
+ """Apply the index map to a set of indices
+
+ Parameters
+ ----------
+ indices : List[PrimExpr]
+ The indices to be mapped
+
+ Returns
+ -------
+ result : List[PrimExpr]
+ The mapped indices
+ """
+ return _ffi_api.IndexMapMapIndices(self, indices)
diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py
index 5f0e169..66ac7b9 100644
--- a/python/tvm/tir/schedule/__init__.py
+++ b/python/tvm/tir/schedule/__init__.py
@@ -22,3 +22,5 @@ from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace
+
+from . import analysis
diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
new file mode 100644
index 0000000..f2fb7c4
--- /dev/null
+++ b/python/tvm/tir/schedule/analysis.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Analysis used in TensorIR scheduling"""
+from typing import List, Optional
+
+from ..buffer import Buffer
+from ..stmt import For
+from ..expr import PrimExpr
+from ..function import IndexMap
+
+from . import _ffi_api
+
+
+def suggest_index_map(
+ buffer: Buffer,
+ indices: List[PrimExpr],
+ loops: List[For],
+ predicate: PrimExpr,
+) -> Optional[IndexMap]:
+ """Provided the access pattern to a buffer, suggest one of the possible
+ layout transformations to maximize the locality of the access pattern.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer to be transformed.
+ indices : List[PrimExpr]
+ The access pattern to the buffer.
+ loops : List[For]
+ The loops above the buffer.
+ predicate : PrimExpr
+ The predicate of the access.
+
+ Returns
+ -------
+ index_map : Optional[IndexMap]
+ The suggested index map. None if no transformation is suggested.
+ """
+ return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member
+ buffer,
+ indices,
+ loops,
+ predicate,
+ )
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index b58b602..1e54f54 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -201,5 +201,7 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
return IndexMap(initial_indices, final_indices);
});
+TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 4deadcf..e74b9ea 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -21,6 +21,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
+#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>
#include <tuple>
@@ -521,6 +522,19 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);
/*!
+ * \brief Provided the access pattern to a buffer, suggest one of the possible
+ * layout transformations to maximize the locality of the access pattern.
+ * \param buffer The buffer to be transformed
+ * \param indices The access pattern to the buffer
+ * \param loops The loops above the buffer
+ * \param predicate The predicate of the access
+ * \param analyzer Arithmetic analyzer
+ */
+Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
+ const Array<For>& loops, const PrimExpr& predicate,
+ arith::Analyzer* analyzer);
+
+/*!
* \brief Checks if the given AST contains the specific operators
* \param stmt The AST statement to be checked
* \param ops The list of operators to be checked
diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc
new file mode 100644
index 0000000..144b3a5
--- /dev/null
+++ b/src/tir/schedule/analysis/layout.cc
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Calculate the strides of the buffer
+ * \param buffer The buffer
+ * \return The strides
+ */
+Array<PrimExpr> GetStrides(const Buffer& buffer) {
+ if (!buffer->strides.empty()) {
+ ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
+ return buffer->strides;
+ }
+ int ndim = buffer->shape.size();
+ if (ndim == 0) {
+ return {};
+ }
+ Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
+ PrimExpr stride = make_const(buffer->DefaultIndexType(), 1);
+ for (int i = ndim - 1; i >= 0; --i) {
+ strides.Set(i, stride);
+ stride = stride * buffer->shape[i];
+ }
+ return strides;
+}
+
+/*!
+ * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern
+ * to help decision making in layout transformation
+ */
+class SplitExprCollector {
+ public:
+ /*!
+ * \brief The corresponding IterSplitExpr, simplified for our case
+ * The pattern is `source // lower_factor % extent * scale`
+ */
+ struct SplitExpr {
+ /*! \brief The source variable */
+ Var source;
+ /*! \brief The lower factor of the split expression */
+ int64_t lower_factor;
+ /*! \brief The extent of the split expression */
+ int64_t extent;
+ };
+
+ /*!
+ * \brief Collect the split expressions in the indexing pattern
+ * \param index The indexing pattern
+ * \param input_iters The input iterators' domain
+ * \param predicate The predicate of the affine map
+ * \param require_bijective Whether the affine map is required to be bijective
+ * \param analyzer The analyzer
+ * \return The collected split expressions
+ */
+ static std::vector<SplitExpr> Collect(const PrimExpr& index,
+ const Map<Var, Range>& input_iters, //
+ const PrimExpr& predicate, //
+ bool require_bijective, //
+ arith::Analyzer* analyzer) {
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+ Array<arith::IterSumExpr> iter_sum_exprs = arith::DetectIterMap(
+ {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx);
+ if (iter_sum_exprs.empty()) {
+ return {};
+ }
+ ICHECK_EQ(iter_sum_exprs.size(), 1);
+ if (iter_sum_exprs[0]->args.size() == 0) {
+ return {};
+ }
+ SplitExprCollector collector;
+ collector.Visit(iter_sum_exprs[0]);
+ if (collector.failed_) {
+ return {};
+ }
+ return std::move(collector.exprs_);
+ }
+
+ private:
+ void Visit(const arith::IterSplitExpr& expr) {
+ if (const auto* var = expr->source->source.as<tir::VarNode>()) {
+ const int64_t* lower_factor = as_const_int(expr->lower_factor);
+ const int64_t* extent = as_const_int(expr->extent);
+ if (lower_factor == nullptr || extent == nullptr) {
+ failed_ = true;
+ return;
+ }
+ exprs_.push_back(SplitExpr{GetRef<Var>(var), *lower_factor, *extent});
+ } else if (const auto* iter_sum_expr = expr->source->source.as<arith::IterSumExprNode>()) {
+ Visit(GetRef<arith::IterSumExpr>(iter_sum_expr));
+ } else {
+ ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey();
+ }
+ }
+
+ void Visit(const arith::IterSumExpr& expr) {
+ for (const arith::IterSplitExpr& arg : expr->args) {
+ Visit(arg);
+ }
+ }
+
+ /*! \brief Whether the analysis failed */
+ bool failed_ = false;
+ /*! \brief The collected split expressions */
+ std::vector<SplitExpr> exprs_;
+};
+
+Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
+ const Array<For>& loops, const PrimExpr& predicate,
+ arith::Analyzer* analyzer) {
+ int ndim = buffer->shape.size();
+ int n_loops = loops.size();
+ // Step 1. Collect the domains and indices of loop variables
+ Map<Var, Range> input_iters;
+ std::unordered_map<const VarNode*, int> var2id;
+ var2id.reserve(n_loops);
+ for (int i = 0; i < n_loops; ++i) {
+ const For& loop = loops[i];
+ input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+ var2id.emplace(loop->loop_var.get(), i);
+ }
+ // Step 2. Calculate a functor that flattens a multi-dimensional index
+ auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()](
+ const Array<PrimExpr>& indices) -> PrimExpr {
+ PrimExpr flatten_index = make_const(dtype, 0);
+ for (int i = 0; i < ndim; ++i) {
+ flatten_index = flatten_index + strides[i] * indices[i];
+ }
+ return flatten_index;
+ };
+ // Step 3. Detect the IterSplitExpr of the indexing pattern
+ std::vector<SplitExprCollector::SplitExpr> split_exprs = SplitExprCollector::Collect(
+ /*index=*/f_flatten_index(indices), input_iters, predicate,
+ /*require_bijective=*/false, analyzer);
+ if (split_exprs.empty()) {
+ return NullOpt;
+ }
+ // Step 4. Sort the order of the split expressions
+ std::vector<int> order(split_exprs.size(), 0);
+ std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; });
+ std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool {
+ const SplitExprCollector::SplitExpr& a = split_exprs[_a];
+ const SplitExprCollector::SplitExpr& b = split_exprs[_b];
+ int a_var_id = var2id.at(a.source.get());
+ int b_var_id = var2id.at(b.source.get());
+ if (a_var_id != b_var_id) {
+ return a_var_id < b_var_id;
+ }
+ return a.lower_factor > b.lower_factor;
+ });
+ // Step 5. Create the indexing mapping
+ auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
+ split_exprs = std::move(split_exprs), //
+ order = std::move(order), //
+ shape = buffer->shape, //
+ analyzer //
+ ](Array<Var> indices) -> Array<PrimExpr> {
+ ICHECK_EQ(indices.size(), shape.size());
+ for (int i = 0, n = indices.size(); i < n; ++i) {
+ analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
+ }
+ PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
+ int ndim = split_exprs.size();
+ // Step 5.1. Split the flattened index according to `split_exprs`
+ std::vector<PrimExpr> split;
+ split.reserve(ndim);
+ for (int i = ndim - 1; i >= 0; --i) {
+ index = analyzer->Simplify(index);
+ int64_t extent = split_exprs[i].extent;
+ split.push_back(analyzer->Simplify(floormod(index, extent)));
+ index = floordiv(index, extent);
+ }
+ std::reverse(split.begin(), split.end());
+ // Step 5.2. Reorder the indexing pattern according to `order`
+ Array<PrimExpr> results;
+ results.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ results.push_back(split[order[i]]);
+ }
+ return results;
+ };
+ return IndexMap::FromFunc(ndim, f_alter_layout);
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
+ .set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Array<For> loops,
+ PrimExpr predicate) {
+ arith::Analyzer analyzer;
+ return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer);
+ });
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
new file mode 100644
index 0000000..760b412
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import List
+
+from tvm.tir import (
+ Evaluate,
+ For,
+ ForKind,
+ IndexMap,
+ Var,
+ decl_buffer,
+ floordiv,
+ floormod,
+)
+from tvm.tir.analysis import expr_deep_equal
+from tvm.tir.schedule.analysis import suggest_index_map
+
+
+def _make_vars(*args: str) -> List[Var]:
+ return [Var(arg, dtype="int32") for arg in args]
+
+
+def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]:
+ assert len(loop_vars) == len(extents)
+ return [
+ For(
+ loop_var=loop_var,
+ min_val=0,
+ extent=extent,
+ kind=ForKind.SERIAL,
+ body=Evaluate(0),
+ )
+ for loop_var, extent in zip(loop_vars, extents)
+ ]
+
+
+def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
+ iters_1 = map1.map_indices(map2.initial_indices)
+ iters_2 = map2.final_indices
+ assert len(iters_1) == len(iters_2)
+ for iter1, iter2 in zip(iters_1, iters_2):
+ assert expr_deep_equal(iter1, iter2)
+
+
+def test_suggest_index_map_simple():
+ i, j = _make_vars("i", "j")
+ index_map = suggest_index_map(
+ buffer=decl_buffer(shape=[8, 256]),
+ indices=[
+ floordiv(i, 16) * 4 + floordiv(j, 16),
+ floormod(i, 16) * 16 + floormod(j, 16),
+ ],
+ loops=_make_loops(
+ loop_vars=[i, j],
+ extents=[32, 64],
+ ),
+ predicate=True,
+ )
+ expected_index_map = IndexMap.from_func(
+ lambda x, y: [
+ floordiv(x, 4),
+ floordiv(y, 16),
+ floormod(x, 4),
+ floormod(y, 16),
+ ],
+ )
+ _assert_equal_index_map(index_map, expected_index_map)
+
+
+def test_suggest_index_map_bijective():
+ i, j = _make_vars("i", "j")
+ index_map = suggest_index_map(
+ buffer=decl_buffer(shape=[8]),
+ indices=[floormod(j, 4) * 2 + i],
+ loops=_make_loops(
+ loop_vars=[i, j],
+ extents=[2, 32],
+ ),
+ predicate=True,
+ )
+ expected_index_map = IndexMap.from_func(
+ lambda x: [
+ floormod(x, 2),
+ floordiv(x, 2),
+ ],
+ )
+ _assert_equal_index_map(index_map, expected_index_map)
+
+
+if __name__ == "__main__":
+ test_suggest_index_map_simple()
+ test_suggest_index_map_bijective()