This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9bba7580b0 [TIR, analysis] Add GetAutoTensorizeMappingInfo to generate
transforms for auto tensorization (#11740)
9bba7580b0 is described below
commit 9bba7580b0dcaea4963bd6b35df0bf6bf867b8ff
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Jun 18 18:04:52 2022 -0700
[TIR, analysis] Add GetAutoTensorizeMappingInfo to generate transforms for
auto tensorization (#11740)
This PR added a utility function `GetAutoTensorizeMappingInfo` to propose
mapping from workload block iters to the iters in the tensor intrin. An example
usage is conv2d, where the computation block has more iters than the matmul
tensor intrin.
---
python/tvm/meta_schedule/testing/te_workload.py | 68 +++++
python/tvm/tir/schedule/analysis.py | 34 +++
src/tir/schedule/analysis.h | 50 ++++
src/tir/schedule/analysis/analysis.cc | 275 +++++++++++++++++++--
src/tir/schedule/ir_comparator.cc | 126 +++++++++-
src/tir/schedule/ir_comparator.h | 52 +++-
.../python/unittest/test_tir_schedule_analysis.py | 56 ++++-
7 files changed, 630 insertions(+), 31 deletions(-)
diff --git a/python/tvm/meta_schedule/testing/te_workload.py
b/python/tvm/meta_schedule/testing/te_workload.py
index 52f5f49b0a..28a2df628c 100644
--- a/python/tvm/meta_schedule/testing/te_workload.py
+++ b/python/tvm/meta_schedule/testing/te_workload.py
@@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: #
pylint: disable=invalid-
return (a, b)
+def conv2d_nhwc_f16( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+):
+ inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16")
+ weight = te.placeholder(
+ (kernel_size, kernel_size, CI // groups, CO), name="weight",
dtype="float16"
+ )
+ batch_size, in_h, in_w, _ = inputs.shape
+ k_h, k_w, channel_per_group, out_channel = weight.shape
+ out_channel_per_group = out_channel // groups
+
+ out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+ out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+ rh = te.reduce_axis((0, k_h), name="rh")
+ rw = te.reduce_axis((0, k_w), name="rw")
+ rc = te.reduce_axis((0, channel_per_group), name="rc")
+
+ padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+ output = te.compute(
+ (batch_size, out_h, out_w, out_channel),
+ lambda n, h, w, co: te.sum(
+ (
+ tir.Cast(
+ value=padded[
+ n,
+ h * stride + rh * dilation,
+ w * stride + rw * dilation,
+ co // out_channel_per_group * channel_per_group + rc,
+ ],
+ dtype="float32",
+ )
+ * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32")
+ ),
+ axis=[rh, rw, rc],
+ ),
+ name="conv2d_nhwc",
+ )
+ return (inputs, weight, output)
+
+
+def batch_matmul_nkkm_f16( # pylint: disable=invalid-name,missing-docstring
+ B: int,
+ N: int,
+ M: int,
+ K: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ x = te.placeholder((B, N, K), name="X", dtype="float16")
+ y = te.placeholder((B, K, M), name="Y", dtype="float16")
+ k = te.reduce_axis((0, K), name="k")
+ z = te.compute( # pylint: disable=invalid-name
+ (B, N, M),
+ lambda b, i, j: te.sum(
+ tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]),
axis=[k]
+ ),
+ name="Z",
+ )
+ return (x, y, z)
+
+
def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
workload_func, params = CONFIGS[name]
return te.create_prim_func(workload_func(*params[idx])) # type: ignore
diff --git a/python/tvm/tir/schedule/analysis.py
b/python/tvm/tir/schedule/analysis.py
index 71ff024217..cdb4aa9cfa 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -87,3 +87,37 @@ def get_tensorize_loop_mapping(
TensorizeInfo structure if a valid mapping is found, None otherwise
"""
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type:
ignore
+
+
+@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
+class AutoTensorizeMappingInfo(Object):
+ """Necessary information used to perform transformations for
tensorization."""
+
+
+def get_auto_tensorize_mapping_info(
+ sch: Schedule, block: BlockRV, desc_func: PrimFunc
+) -> Optional[AutoTensorizeMappingInfo]:
+ """Get mapping info between a target block and an intrinsic description
including layout
+ transformations to apply.
+
+ Parameters
+ ----------
+ sch : Schedule
+ The schedule to be tensorized
+ block : BlockRV
+ The compute block for auto tensorization
+ desc_func : PrimFunc
+ The prim func describing the computation to be tensorized
+
+ Returns
+ -------
+ auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
+ AutoTensorizeMappingInfo structure if potential mappings found, None
otherwise.
+
+ Note
+ ----
+ Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can
be tensorized.
+ We will need to apply the suggested layout transformations and then match
against the tensor
+ intrinsics.
+ """
+ return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) #
type: ignore
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 5adc4f8f1b..b30cef829f 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -707,6 +707,56 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const
tir::ScheduleState& self,
const tir::StmtSRef&
block_sref,
const tir::PrimFunc&
desc_func);
+/*!\brief Necessary information used to perform transformations for
tensorization */
+class AutoTensorizeMappingInfoNode : public Object {
+ public:
+ /*! \brief Possible mappings to apply to block iters */
+ Array<IndexMap> mappings;
+
+ /* Additional information from AutoTensorizeComparator */
+
+ /*! \brief Mapping from LHS buffer to RHS buffer */
+ Map<Buffer, Buffer> lhs_buffer_map;
+ /*! \brief Buffer indices on RHS */
+ Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
+ /*! \brief Block iters on LHS */
+ Array<IterVar> lhs_iters;
+ /*! \brief Block iters on RHS */
+ Array<IterVar> rhs_iters;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("mappings", &mappings);
+ v->Visit("lhs_buffer_map", &lhs_buffer_map);
+ v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
+ v->Visit("lhs_iters", &lhs_iters);
+ v->Visit("rhs_iters", &rhs_iters);
+ }
+
+ static constexpr const char* _type_key =
"tir.schedule.AutoTensorizeMappingInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
+};
+
+class AutoTensorizeMappingInfo : public ObjectRef {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo,
ObjectRef,
+ AutoTensorizeMappingInfoNode);
+};
+
+/*!
+ * \brief Get mapping info between a target block and an intrinsic description
including layout
+ * transformations to apply.
+ * \param self The schedule state
+ * \param block_sref The compute block for auto tensorization
+ * \param desc_func The prim func describing the computation to be tensorized
+ * \return AutoTensorizeMappingInfo structure if a potential mapping is found,
NullOpt otherwise.
+ * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the
block can be tensorized.
+ * We will need to apply the suggested layout transformations and then match
against the tensor
+ * intrinsics.
+ */
+Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const
ScheduleState& self,
+ const StmtSRef&
block_sref,
+ const PrimFunc&
desc_func);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 7def8b8674..3ee1ed28b8 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -19,6 +19,7 @@
#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>
+#include "../ir_comparator.h"
#include "../utils.h"
namespace tvm {
@@ -2085,39 +2086,60 @@ bool NeedsRFactorOrCrossThreadReduction(const
tir::ScheduleState& self, //
TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
-Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
- const tir::StmtSRef&
block_sref,
- const tir::PrimFunc&
desc_func) {
- arith::Analyzer analyzer;
- const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
- // Step 1. Analyze desc_func, extract its block, loops and loop vars
- const tir::BlockRealizeNode* desc_block = nullptr;
+/*! \brief Auxiliary data structure of information extracted from tensor
intrin description */
+struct TensorIntrinDescInfo {
+ /*! \brief The block of the description function, which is the (unique)
direct child of the root
+ * block.
+ */
+ const BlockRealizeNode* desc_block = nullptr;
+ /*! \brief The loops of the description function, in the order from outer
loops to inner ones. */
std::vector<const tir::ForNode*> desc_loops;
+ /*! \brief The loop variables. */
std::unordered_set<const tir::VarNode*> desc_loop_vars;
- const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+};
+
+/*!
+ * \brief Extract auxiliary information from the tensor intrin description.
+ * \param analyzer The arithmetic analyzer
+ * \param desc_func The description PrimFunc
+ * \return The auxiliary information
+ */
+TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
+ const PrimFunc& desc_func) {
+ TensorIntrinDescInfo info;
+ const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>();
ICHECK(desc_scope_realize);
{
- auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
- &analyzer](const ObjectRef& obj) -> bool {
+ auto f_visit = [&](const ObjectRef& obj) -> bool {
// Extract the block
- if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
- desc_block = block;
+ if (const auto* block = obj.as<BlockRealizeNode>()) {
+ info.desc_block = block;
return false;
}
- // Extract loops
- if (const auto* loop = obj.as<tir::ForNode>()) {
- desc_loops.push_back(loop);
- desc_loop_vars.insert(loop->loop_var.get());
- if (!analyzer.CanProve(loop->min == 0)) {
+ // Extract the loops
+ if (const auto* loop = obj.as<ForNode>()) {
+ info.desc_loops.push_back(loop);
+ info.desc_loop_vars.insert(loop->loop_var.get());
+ if (!analyzer->CanProve(loop->min == 0)) {
return false;
}
}
return true;
};
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
- std::reverse(desc_loops.begin(), desc_loops.end());
- ICHECK(desc_block);
+ std::reverse(info.desc_loops.begin(), info.desc_loops.end());
+ ICHECK(info.desc_block);
}
+ return info;
+}
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+ const tir::StmtSRef&
block_sref,
+ const tir::PrimFunc&
desc_func) {
+ arith::Analyzer analyzer;
+ const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+ // Step 1. Analyze desc_func, extract its block, loops and loop vars
+ TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer,
desc_func);
// Step 2. Collect loops from block_sref
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block,
scope_sref);
@@ -2138,6 +2160,9 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const
tir::ScheduleState& self,
std::reverse(block_loops.begin(), block_loops.end());
}
// Step 3. Map from block loops to desc block loops
+ const std::vector<const ForNode*>& desc_loops = desc_info.desc_loops;
+ const std::unordered_set<const VarNode*>& desc_loop_vars =
desc_info.desc_loop_vars;
+ const BlockRealizeNode* desc_block = desc_info.desc_block;
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
const int n_block_vars = block->iter_values.size();
const int n_desc_vars = desc_block->iter_values.size();
@@ -2240,5 +2265,217 @@
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block),
desc_func);
});
+/******** Auto Tensorization ********/
+
+/*! \brief IndexMap proposer for layout transformation in auto tensorization.
*/
+class AutoTensorizeMappingProposer {
+ public:
+ static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator*
extractor,
+ arith::Analyzer* analyzer) {
+ AutoTensorizeMappingProposer proposer(extractor, analyzer);
+ proposer.CollectFeasibleSet();
+ return proposer.ProposeAllFuseMapping();
+ }
+
+ private:
+ explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator*
extractor,
+ arith::Analyzer* analyzer)
+ : extractor_(extractor), analyzer_(analyzer) {}
+
+ using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ void CollectFeasibleSet() {
+ // Collect the set of potential iter var mapping between the workload and
the tensor intrin.
+ // We analyze the appearance of each variable in the buffer indices of
each buffer on LHS and
+ // RHS. The appearance of a variable in the buffer indices is encoded as
bit-masks (BufferMask).
+ // Variables on the LHS and the RHS with the same bit-mask and the same
iter type are potential
+ // mappings.
+ //
+ // For example, consider the conv2d case. We will try to match the workload
+ // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] *
W[rh, rw, rc, c]
+ // against a matmul tensor intrin
+ // C[m, n] = sum_{k} A[m, k] * B[k, n]
+ // First we extract the correspondence of the buffers: conv2d <=> C, A <=>
X, B <=> W.
+ // Then for each variable, we extract the buffers where it is used for
indexing.
+ // Take the variable m on the RHS as an example. m is used to index buffer
A and C. On the LHS,
+ // we will find the variables used to index only the exact corresponding
buffers conv2d and X
+ // (the variable is not allowed to index other buffers). In this case, n,
h, w are used to index
+ // both buffers conv2d and X, and not in other buffers. Therefore, {n, h,
w} <=> m is a potential
+ // mapping.
+
+ // Note: the mapping is not unique when multiple variables on RHS have the
same bit-mask.
+ // This is currently not supported.
+
+ using BufferMask = std::vector<bool>;
+
+ // Step 1: Assign an index to each buffer in LHS and RHS
+ std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual>
rhs_buffer_index;
+ std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual>
lhs_buffer_index;
+ {
+ int i = 0;
+ for (const auto& kv : extractor_->rhs_buffer_map_) {
+ const Buffer& rhs_buffer = kv.first;
+ const Buffer& lhs_buffer = kv.second;
+ rhs_buffer_index[rhs_buffer] = i;
+ lhs_buffer_index[lhs_buffer] = i;
+ ++i;
+ }
+ }
+
+ // Step 2: Compute the buffer mask
+ ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size());
+ int num_buffers = rhs_buffer_index.size();
+ std::unordered_map<const VarNode*, std::vector<bool>> rhs_buffer_masks,
lhs_buffer_masks;
+ // helper function to initialize or update the buffer mask
+ auto update_mask = [&](const VarNode* var,
+ std::unordered_map<const VarNode*,
std::vector<bool>>* masks, int i) {
+ if (!masks->count(var)) {
+ (*masks)[var].resize(num_buffers);
+ }
+ (*masks)[var][i] = true;
+ };
+
+ for (const auto& it : extractor_->rhs_buffer_indices_map_) {
+ const Buffer& rhs_buffer = it.first;
+ for (const PrimExpr& rhs_index : it.second) {
+ if (const VarNode* var_node = rhs_index.as<VarNode>()) {
+ update_mask(var_node, &rhs_buffer_masks,
rhs_buffer_index.at(rhs_buffer));
+ } else {
+ LOG(FATAL) << "ValueError: Buffer index " << rhs_index
+ << " other that variables in tensor intrinsics is not
supported.";
+ }
+ }
+
+ auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer);
+ ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end());
+ const Buffer& lhs_buffer = lhs_buffer_it->second;
+ for (const PrimExpr& index :
extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) {
+ PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
+ if (const VarNode* var = obj.as<VarNode>()) {
+ update_mask(var, &lhs_buffer_masks,
lhs_buffer_index.at(lhs_buffer));
+ }
+ return true;
+ });
+ }
+ }
+
+ // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure
LHS and RHS vars
+ // have the same iter type.
+ std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
+ for (const auto& kv : rhs_buffer_masks) {
+ const VarNode* rhs_var = kv.first;
+ const BufferMask& mask = kv.second;
+ mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var));
+ }
+ std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
+ for (const auto& iter : extractor_->rhs_iters_) {
+ rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type);
+ }
+ for (const auto& iter : extractor_->lhs_iters_) {
+ auto& potential_mappings = lhs_feasible_vars_[iter->var];
+ VarSet rhs_candidates =
mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
+ std::copy_if(
+ rhs_candidates.begin(), rhs_candidates.end(),
+ std::inserter(potential_mappings, potential_mappings.begin()),
+ [&](const Var& var) { return rhs_var_iter_type.at(var.get()) ==
iter->iter_type; });
+ }
+ }
+
+ Array<IndexMap> ProposeAllFuseMapping() {
+ // Now we have calculated potential mapping for each iter var on LHS. For
iters on LHS mapped to
+ // the same iter on RHS, they will be fused in the original order in LHS
block iters. We will
+ // generate IndexMap to represent such fusion on LHS. For example, if n,
h, w on LHS are mapped
+ // to the same iter var on RHS, we will produce index map `lambda n, h, w:
fuse(n, h, w)`, where
+ // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... ) * vn_extent + vn
+
+ // the parameters of the result index map, each parameter corresponds to a
LHS iter
+ Array<Var> index_map_src;
+ // the outputs of the result index map
+ Array<PrimExpr> index_map_tgt;
+
+ // Step 1: Collect extents of LHS iters and prepare the initial indices of
the IndexMap
+ Map<Var, PrimExpr> lhs_iter_extents;
+ for (const auto& iter : extractor_->lhs_iters_) {
+ lhs_iter_extents.Set(iter->var, iter->dom->extent);
+ index_map_src.push_back(iter->var.copy_with_suffix(""));
+ }
+
+ // Step 2: Each iter on RHS has a group of corresponding iters on LHS.
Initialize the fusion
+ // result for each group of iters on LHS.
+ Map<Var, PrimExpr> fused_lhs_iters;
+ for (const auto& iter : extractor_->rhs_iters_) {
+ fused_lhs_iters.Set(iter->var, 0);
+ }
+
+ // Step 3: Fuse LHS iters mapped to the same RHS iter
+ for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
+ const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
+ const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
+ if (rhs_candidates.empty()) {
+ // put unmapped iters at the beginning
+ index_map_tgt.push_back(index_map_src[i]);
+ } else if (rhs_candidates.size() == 1) {
+ Var rhs_var = *rhs_candidates.begin();
+ PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var);
+ PrimExpr updated_fused_lhs =
+ fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
+ fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+ } else {
+ // non-unique mapping is not supported
+ return {};
+ }
+ }
+ for (const auto& iter : extractor_->rhs_iters_) {
+ index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
+ }
+ // At most one mapping is supported.
+ return {IndexMap(index_map_src, index_map_tgt)};
+ }
+
+ private:
+ // The extractor that has extracted information for auto tensorization from
the workload and the
+ // tensor intrin.
+ const AutoTensorizeComparator* extractor_;
+ // The arithmetic analyzer.
+ arith::Analyzer* analyzer_;
+ /*! \brief Potential mappings on RHS for each variable on LHS */
+ std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual>
lhs_feasible_vars_;
+};
+
+Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const
tir::ScheduleState& self,
+ const
tir::StmtSRef& block_sref,
+ const
tir::PrimFunc& desc_func) {
+ arith::Analyzer analyzer;
+ const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+ // Step 1. Analyze desc_func, extract its block, loops and loop vars
+ TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer,
desc_func);
+ // Step 2. Check if `desc_block` matches `block`
+ // Ignore the scope of buffers when comparing, since we can do
cache_read/write
+ const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+ const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block,
scope_sref);
+ AutoTensorizeComparator extractor(self->mod);
+ if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) {
+ return NullOpt;
+ }
+ Array<IndexMap> mappings =
AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer);
+ if (mappings.empty()) {
+ return NullOpt;
+ }
+ ObjectPtr<AutoTensorizeMappingInfoNode> ret =
make_object<AutoTensorizeMappingInfoNode>();
+ ret->mappings = std::move(mappings);
+ ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_);
+ ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_);
+ ret->lhs_iters = std::move(extractor.lhs_iters_);
+ ret->rhs_iters = std::move(extractor.rhs_iters_);
+ return AutoTensorizeMappingInfo(ret);
+}
+
+TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode);
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
+ .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
+ return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block),
desc_func);
+ });
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.cc
b/src/tir/schedule/ir_comparator.cc
index 58c502379a..d8ac40ef05 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -333,12 +333,12 @@ bool TensorizeComparator::CompareBufferAccess(const T*
lhs, const T* rhs) {
return true;
}
-template <typename T, typename F>
-bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>&
rhs, F cmp) {
+template <typename T, typename Self, typename F>
+bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>&
rhs, F Self::*cmp) {
if (lhs.same_as(rhs)) return true;
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
- if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+ if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;
}
return true;
}
@@ -355,5 +355,125 @@ void TensorizeComparator::EmitError(const std::string&
error_message) {
error_messages_.push_back(error_message);
}
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const
PrimExpr& other) {
+ return false;
+}
+
+bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt&
other) {
+ return false;
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BlockNode>();
+ // Check block equality.
+ // All iter vars and buffer regions including the order should match.
+ // When checking iter vars, DefEqual is used to remap variables.
+ if (!is_scope_block) {
+ if (!CompareArray(op->iter_vars, rhs->iter_vars,
&AutoTensorizeComparator::CompareIterVar)) {
+ return false;
+ }
+ if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+ return false;
+ }
+ if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+ &AutoTensorizeComparator::CompareBuffer)) {
+ return false;
+ }
+ for (const IterVar& block_iter : op->iter_vars) {
+ inner_iter_dom_map_.Set(block_iter->var,
arith::IntSet::FromRange(block_iter->dom));
+ }
+ } else {
+ auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters)
-> bool {
+ for (const auto& iter : op->iter_vars) {
+ analyzer_.Bind(iter->var, iter->dom);
+ if (iter->iter_type == IterVarType::kDataPar ||
+ iter->iter_type == IterVarType::kCommReduce) {
+ iters.push_back(iter);
+ } else {
+ return false;
+ }
+ }
+ return true;
+ };
+ if (!collect_iter(op, lhs_iters_)) {
+ return false;
+ }
+ if (!collect_iter(rhs, rhs_iters_)) {
+ return false;
+ }
+ }
+ is_scope_block = false;
+ return VisitStmt(op->body, rhs->body);
+}
+
+bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer&
rhs) {
+ if (lhs.same_as(rhs)) return true;
+ auto it = rhs_buffer_map_.find(rhs);
+ bool equal;
+ if (it != rhs_buffer_map_.end()) {
+ equal = (*it).second.same_as(lhs);
+ } else {
+ // Remap both buffer itself and buffer data, skip buffer shape and scope
+ equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+ if (equal) {
+ rhs_buffer_map_[rhs] = lhs;
+ lhs_buffer_map_[lhs] = rhs;
+ }
+ }
+ return equal;
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const
Stmt& other) {
+ const auto* rhs = other.as<BufferStoreNode>();
+ return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const
PrimExpr& other) {
+ const auto* rhs = other.as<BufferLoadNode>();
+ return CompareBufferAccess(op, rhs);
+}
+
+template <typename T>
+bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
+ if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+ auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer);
+ if (it_lhs == lhs_buffer_indices_map_.end()) {
+ if (rhs_buffer_indices_map_.find(rhs->buffer) !=
rhs_buffer_indices_map_.end()) {
+ return false;
+ }
+ std::vector<PrimExpr> lhs_indices;
+ for (const auto& index : lhs->indices) {
+ lhs_indices.push_back(analyzer_.Simplify(index));
+ }
+ for (const auto& index : rhs->indices) {
+ if (!index.template as<VarNode>()) return false;
+ }
+ lhs_buffer_indices_map_[lhs->buffer] = lhs_indices;
+ rhs_buffer_indices_map_[rhs->buffer] = rhs->indices;
+ } else {
+ auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer);
+ if (it_rhs == rhs_buffer_indices_map_.end()) {
+ return false;
+ }
+ auto indices_check = [&](const Array<PrimExpr>& indices,
+ const Array<PrimExpr>& old_indices) -> bool {
+ if (indices.size() != old_indices.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) {
+ return false;
+ }
+ }
+ return true;
+ };
+ if (!indices_check(lhs->indices, it_lhs->second)) return false;
+ if (!indices_check(rhs->indices, it_rhs->second)) return false;
+ }
+ return true;
+}
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 359677d885..394d828673 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -90,8 +90,8 @@ class TensorizeComparator : public ExprComparator, public
StmtComparator {
bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const
Map<String, ObjectRef>& rhs);
template <typename T>
bool CompareBufferAccess(const T* lhs, const T* rhs);
- template <typename T, typename F>
- bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+ template <typename T, typename Self, typename F>
+ bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp);
bool CompareRange(const Range& lhs, const Range& rhs);
bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
void EmitError(const std::string& error_message);
@@ -110,6 +110,54 @@ class TensorizeComparator : public ExprComparator, public
StmtComparator {
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual>
equal_map_;
};
+/*!
+ * \brief IR comparator for auto tensorization.
+ * This comparator is used to extract correspondence between the IR of the
workload (LHS) and the
+ * tensor intrin (RHS). Unlike `TensorizeComparator`, this comparator has
relaxed requirements
+ * during comparison. It ignores the loop structure (number of loops and their
extents) and buffer
+ * indices. It only requires the LHS and the RHS to have the same arithmetic
operations and the same
+ * dtype. With such relaxed requirements, workloads that can only match the
tensor intrin after
+ * certain transformations (e.g. im2col for conv2d) are allowed for auto
tensorization.
+ */
+class AutoTensorizeComparator : public TensorizeComparator {
+ public:
+ explicit AutoTensorizeComparator(const IRModule& lhs_mod)
+ : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {}
+
+ private:
+ bool VisitExprDefault_(const Object* op, const PrimExpr& other) override;
+ bool VisitStmtDefault_(const Object* op, const Stmt& other) override;
+
+ bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+ bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+
+ bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+
+ bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override;
+ template <typename T>
+ bool CompareBufferAccess(const T* lhs, const T* rhs);
+
+ public:
+ // Additional information extracted from LHS (the workload) and RHS (the
tensor intrin).
+
+ /*! \brief Block iters in the LHS stmt. */
+ std::vector<IterVar> lhs_iters_;
+ /*! \brief Block iters in the RHS stmt. */
+ std::vector<IterVar> rhs_iters_;
+ /*! \brief The buffer and its access indices in the LHS stmt. */
+ std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+ lhs_buffer_indices_map_;
+ /*! \brief The buffer and its access indices in the RHS stmt. */
+ std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+ rhs_buffer_indices_map_;
+ /*! \brief Map from LHS buffer to RHS buffer */
+ std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> lhs_buffer_map_;
+
+ private:
+ /*! \brief The domain of the inner block iters. */
+ Map<Var, arith::IntSet> inner_iter_dom_map_;
+};
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py
b/tests/python/unittest/test_tir_schedule_analysis.py
index 19be0b8699..6761203a5a 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -16,14 +16,22 @@
# under the License.
# pylint: disable=missing-docstring
from typing import List
-
+import pytest
import tvm
+import tvm.testing
+from tvm.tir.function import TensorIntrin
from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
+from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer,
floordiv, floormod, Schedule
from tvm.tir.analysis import expr_deep_equal
-from tvm.tir.schedule.analysis import suggest_index_map,
get_tensorize_loop_mapping, TensorizeInfo
+from tvm.tir.schedule.analysis import (
+ get_auto_tensorize_mapping_info,
+ suggest_index_map,
+ get_tensorize_loop_mapping,
+ TensorizeInfo,
+)
from tvm.script import tir as T
from tvm.tir.stmt_functor import pre_order_visit
from tvm.meta_schedule.testing import te_workload
@@ -252,9 +260,43 @@ def test_get_tensorize_loop_mapping_matmul_mma():
assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
+def check_index_map(workload, block_name, intrin_name, expected_index_map):
+ s = Schedule(workload)
+ block = s.get_block(block_name)
+ desc_func = TensorIntrin.get(intrin_name).desc
+ info = get_auto_tensorize_mapping_info(s, block, desc_func)
+ assert len(info.mappings) == 1
+ assert
IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
+
+
+def test_get_auto_tensorize_mapping_info_conv2d():
+ conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64,
3, 1, 1))
+ check_index_map(
+ conv2d,
+ "conv2d_nhwc",
+ WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+ lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw
* 64 + rc),
+ )
+
+
+def test_get_auto_tensorize_mapping_info_conv2d_unit_batch():
+ conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(1, 16, 16, 64, 64,
3, 1, 1))
+ check_index_map(
+ conv2d,
+ "conv2d_nhwc",
+ WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+ # unit iter is not mapped
+ lambda n, h, w, c, rh, rw, rc: (n, h * 16 + w, c, rh * 192 + rw * 64 +
rc),
+ )
+
+
[email protected]("b,m,n,k", [(1, 512, 512, 512), (16, 32, 32, 32)])
+def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
+ matmul = create_prim_func(te_workload.batch_matmul_nkkm_f16(b, m, n, k))
+ check_index_map(
+ matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, lambda b, m, n, k:
(b, m, n, k)
+ )
+
+
if __name__ == "__main__":
- test_suggest_index_map_simple()
- test_suggest_index_map_bijective()
- test_get_tensorize_loop_mapping_dense_vnni()
- test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
- test_get_tensorize_loop_mapping_matmul_mma()
+ tvm.testing.main()