This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9bba7580b0 [TIR, analysis] Add GetAutoTensorizeMappingInfo to generate transforms for auto tensorization (#11740)
9bba7580b0 is described below

commit 9bba7580b0dcaea4963bd6b35df0bf6bf867b8ff
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Jun 18 18:04:52 2022 -0700

    [TIR, analysis] Add GetAutoTensorizeMappingInfo to generate transforms for auto tensorization (#11740)
    
    This PR added a utility function `GetAutoTensorizeMappingInfo` to propose a
    mapping from the workload's block iters to the iters of the tensor intrinsic.
    An example use case is conv2d, where the computation block has more iters
    than the matmul tensor intrinsic.
---
 python/tvm/meta_schedule/testing/te_workload.py    |  68 +++++
 python/tvm/tir/schedule/analysis.py                |  34 +++
 src/tir/schedule/analysis.h                        |  50 ++++
 src/tir/schedule/analysis/analysis.cc              | 275 +++++++++++++++++++--
 src/tir/schedule/ir_comparator.cc                  | 126 +++++++++-
 src/tir/schedule/ir_comparator.h                   |  52 +++-
 .../python/unittest/test_tir_schedule_analysis.py  |  56 ++++-
 7 files changed, 630 insertions(+), 31 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/te_workload.py 
b/python/tvm/meta_schedule/testing/te_workload.py
index 52f5f49b0a..28a2df628c 100644
--- a/python/tvm/meta_schedule/testing/te_workload.py
+++ b/python/tvm/meta_schedule/testing/te_workload.py
@@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]:  # 
pylint: disable=invalid-
     return (a, b)
 
 
+def conv2d_nhwc_f16(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+):
+    inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16")  # NHWC activations in fp16
+    weight = te.placeholder(  # weight layout: (KH, KW, CI // groups, CO), fp16
+        (kernel_size, kernel_size, CI // groups, CO), name="weight", 
dtype="float16"
+    )
+    batch_size, in_h, in_w, _ = inputs.shape
+    k_h, k_w, channel_per_group, out_channel = weight.shape
+    out_channel_per_group = out_channel // groups  # output channels produced by each group
+
+    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1  # dilated-conv output height
+    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1  # dilated-conv output width
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    rc = te.reduce_axis((0, channel_per_group), name="rc")
+
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0])  # pad only the H and W axes
+    output = te.compute(
+        (batch_size, out_h, out_w, out_channel),
+        lambda n, h, w, co: te.sum(  # fp16 operands are cast to fp32 before multiply-accumulate
+            (
+                tir.Cast(
+                    value=padded[
+                        n,
+                        h * stride + rh * dilation,
+                        w * stride + rw * dilation,
+                        co // out_channel_per_group * channel_per_group + rc,  # first input channel of co's group, plus rc
+                    ],
+                    dtype="float32",
+                )
+                * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32")
+            ),
+            axis=[rh, rw, rc],
+        ),
+        name="conv2d_nhwc",
+    )
+    return (inputs, weight, output)
+
+
+def batch_matmul_nkkm_f16(  # pylint: disable=invalid-name,missing-docstring
+    B: int,
+    N: int,
+    M: int,
+    K: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    x = te.placeholder((B, N, K), name="X", dtype="float16")
+    y = te.placeholder((B, K, M), name="Y", dtype="float16")
+    k = te.reduce_axis((0, K), name="k")  # shared reduction axis over K
+    z = te.compute(  # pylint: disable=invalid-name
+        (B, N, M),
+        lambda b, i, j: te.sum(  # Z[b, i, j] = sum_k X[b, i, k] * Y[b, k, j], accumulated in fp32
+            tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), 
axis=[k]
+        ),
+        name="Z",
+    )
+    return (x, y, z)
+
+
 def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
     workload_func, params = CONFIGS[name]
     return te.create_prim_func(workload_func(*params[idx]))  # type: ignore
diff --git a/python/tvm/tir/schedule/analysis.py 
b/python/tvm/tir/schedule/analysis.py
index 71ff024217..cdb4aa9cfa 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -87,3 +87,37 @@ def get_tensorize_loop_mapping(
         TensorizeInfo structure if a valid mapping is found, None otherwise
     """
     return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: 
ignore
+
+
+@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")  # must match the C++ node's _type_key
+class AutoTensorizeMappingInfo(Object):  # fields (mappings, lhs_buffer_map, ...) come from the C++ VisitAttrs
+    """Necessary information used to perform transformations for 
tensorization."""
+
+
+def get_auto_tensorize_mapping_info(
+    sch: Schedule, block: BlockRV, desc_func: PrimFunc
+) -> Optional[AutoTensorizeMappingInfo]:
+    """Get mapping info between a target block and an intrinsic description 
including layout
+    transformations to apply.
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule to be tensorized
+    block : BlockRV
+        The compute block for auto tensorization
+    desc_func : PrimFunc
+        The prim func describing the computation to be tensorized
+
+    Returns
+    -------
+    auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
+        AutoTensorizeMappingInfo structure if a potential mapping is found, None 
otherwise.
+
+    Notes
+    -----
+    Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can 
be tensorized.
+    We will need to apply the suggested layout transformations and then match 
against the tensor
+    intrinsics.
+    """
+    return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func)  # 
type: ignore
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 5adc4f8f1b..b30cef829f 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -707,6 +707,56 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const 
tir::ScheduleState& self,
                                                 const tir::StmtSRef& 
block_sref,
                                                 const tir::PrimFunc& 
desc_func);
 
+/*! \brief Necessary information used to perform transformations for 
tensorization */
+class AutoTensorizeMappingInfoNode : public Object {
+ public:
+  /*! \brief Candidate index maps to apply to the LHS (workload) block iters */
+  Array<IndexMap> mappings;
+
+  /* Additional information from AutoTensorizeComparator */
+
+  /*! \brief Mapping from LHS (workload) buffers to RHS (tensor intrin) buffers */
+  Map<Buffer, Buffer> lhs_buffer_map;
+  /*! \brief Access indices of each RHS (tensor intrin) buffer */
+  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
+  /*! \brief Block iters of the LHS (workload) block */
+  Array<IterVar> lhs_iters;
+  /*! \brief Block iters of the RHS (tensor intrin) block */
+  Array<IterVar> rhs_iters;
+
+  void VisitAttrs(AttrVisitor* v) {  // exposes the fields above through reflection (e.g. to Python)
+    v->Visit("mappings", &mappings);
+    v->Visit("lhs_buffer_map", &lhs_buffer_map);
+    v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
+    v->Visit("lhs_iters", &lhs_iters);
+    v->Visit("rhs_iters", &rhs_iters);
+  }
+
+  static constexpr const char* _type_key = 
"tir.schedule.AutoTensorizeMappingInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
+};
+
+class AutoTensorizeMappingInfo : public ObjectRef {  // managed reference to AutoTensorizeMappingInfoNode
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, 
ObjectRef,
+                                            AutoTensorizeMappingInfoNode);
+};
+
+/*!
+ * \brief Get mapping info between a target block and an intrinsic description 
including layout
+ * transformations to apply.
+ * \param self The schedule state
+ * \param block_sref The compute block for auto tensorization
+ * \param desc_func The prim func describing the computation to be tensorized
+ * \return AutoTensorizeMappingInfo structure if a potential mapping is found, 
NullOpt otherwise.
+ * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the 
block can be tensorized.
+ * We will need to apply the suggested layout transformations and then match 
against the tensor
+ * intrinsics.
+ */
+Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const 
ScheduleState& self,
+                                                               const StmtSRef& 
block_sref,
+                                                               const PrimFunc& 
desc_func);
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 7def8b8674..3ee1ed28b8 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -19,6 +19,7 @@
 #include <tvm/runtime/container/optional.h>
 #include <tvm/tir/expr.h>
 
+#include "../ir_comparator.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -2085,39 +2086,60 @@ bool NeedsRFactorOrCrossThreadReduction(const 
tir::ScheduleState& self,   //
 
 TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
 
-Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
-                                                const tir::StmtSRef& 
block_sref,
-                                                const tir::PrimFunc& 
desc_func) {
-  arith::Analyzer analyzer;
-  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
-  // Step 1. Analyze desc_func, extract its block, loops and loop vars
-  const tir::BlockRealizeNode* desc_block = nullptr;
+/*! \brief Auxiliary data structure of information extracted from tensor 
intrin description */
+struct TensorIntrinDescInfo {
+  /*! \brief The block of the description function, which is the (unique) 
direct child of the root
+   *         block.
+   */
+  const BlockRealizeNode* desc_block = nullptr;
+  /*! \brief The loops of the description function, in the order from outer 
loops to inner ones. */
   std::vector<const tir::ForNode*> desc_loops;
+  /*! \brief The loop variables. */
   std::unordered_set<const tir::VarNode*> desc_loop_vars;
-  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+};
+
+/*!
+ * \brief Extract auxiliary information from the tensor intrin description.
+ * \param analyzer The arithmetic analyzer
+ * \param desc_func The description PrimFunc
+ * \return The auxiliary information
+ */
+TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
+                                                 const PrimFunc& desc_func) {
+  TensorIntrinDescInfo info;
+  const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>();
   ICHECK(desc_scope_realize);
   {
-    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
-                    &analyzer](const ObjectRef& obj) -> bool {
+    auto f_visit = [&](const ObjectRef& obj) -> bool {
       // Extract the block
-      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
-        desc_block = block;
+      if (const auto* block = obj.as<BlockRealizeNode>()) {
+        info.desc_block = block;
         return false;
       }
-      // Extract loops
-      if (const auto* loop = obj.as<tir::ForNode>()) {
-        desc_loops.push_back(loop);
-        desc_loop_vars.insert(loop->loop_var.get());
-        if (!analyzer.CanProve(loop->min == 0)) {
+      // Extract the loops
+      if (const auto* loop = obj.as<ForNode>()) {
+        info.desc_loops.push_back(loop);
+        info.desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer->CanProve(loop->min == 0)) {  // NOTE(review): children are visited before this node in a post-order walk, so this return value presumably cannot prune anything — confirm
           return false;
         }
       }
       return true;
     };
     tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
-    std::reverse(desc_loops.begin(), desc_loops.end());
-    ICHECK(desc_block);
+    std::reverse(info.desc_loops.begin(), info.desc_loops.end());  // post-order collects inner loops first; flip to outer-to-inner
+    ICHECK(info.desc_block);
   }
+  return info;
+}
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& 
block_sref,
+                                                const tir::PrimFunc& 
desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, 
desc_func);
   // Step 2. Collect loops from block_sref
   const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
   const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, 
scope_sref);
@@ -2138,6 +2160,9 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const 
tir::ScheduleState& self,
     std::reverse(block_loops.begin(), block_loops.end());
   }
   // Step 3. Map from block loops to desc block loops
+  const std::vector<const ForNode*>& desc_loops = desc_info.desc_loops;
+  const std::unordered_set<const VarNode*>& desc_loop_vars = 
desc_info.desc_loop_vars;
+  const BlockRealizeNode* desc_block = desc_info.desc_block;
   ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
   const int n_block_vars = block->iter_values.size();
   const int n_desc_vars = desc_block->iter_values.size();
@@ -2240,5 +2265,217 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
       return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), 
desc_func);
     });
 
+/******** Auto Tensorization ********/
+
+/*! \brief IndexMap proposer for layout transformation in auto tensorization. 
*/
+class AutoTensorizeMappingProposer {
+ public:
+  static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator* 
extractor,
+                                         arith::Analyzer* analyzer) {
+    AutoTensorizeMappingProposer proposer(extractor, analyzer);
+    proposer.CollectFeasibleSet();
+    return proposer.ProposeAllFuseMapping();
+  }
+
+ private:
+  explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* 
extractor,
+                                        arith::Analyzer* analyzer)
+      : extractor_(extractor), analyzer_(analyzer) {}
+
+  using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+  void CollectFeasibleSet() {
+    // Collect the set of potential iter var mappings between the workload and 
the tensor intrin.
+    // We analyze the appearance of each variable in the buffer indices of 
each buffer on LHS and
+    // RHS. The appearance of a variable in the buffer indices is encoded as 
bit-masks (BufferMask).
+    // Variables on the LHS and the RHS with the same bit-mask and the same 
iter type are potential
+    // mappings.
+    //
+    // For example, consider the conv2d case. We will try to match the workload
+    // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, rc] * 
W[rh, rw, rc, c]
+    // against a matmul tensor intrin
+    // C[m, n] = sum_{k} A[m, k] * B[k, n]
+    // First we extract the correspondence of the buffers: conv2d <=> C, A <=> 
X, B <=> W.
+    // Then for each variable, we extract the buffers where it is used for 
indexing.
+    // Take the variable m on the RHS as an example. m is used to index buffers 
A and C. On the LHS,
+    // we will find the variables used to index only the exact corresponding 
buffers conv2d and X
+    // (the variable is not allowed to index other buffers). In this case, n, 
h, w are used to index
+    // both buffers conv2d and X, and no other buffers. Therefore, {n, h, 
w} <=> m is a potential
+    // mapping.
+
+    // Note: the mapping is not unique when multiple variables on RHS have the 
same bit-mask.
+    // This is currently not supported.
+
+    using BufferMask = std::vector<bool>;
+
+    // Step 1: Assign an index to each buffer in LHS and RHS
+    std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> 
rhs_buffer_index;
+    std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> 
lhs_buffer_index;
+    {
+      int i = 0;
+      for (const auto& kv : extractor_->rhs_buffer_map_) {
+        const Buffer& rhs_buffer = kv.first;
+        const Buffer& lhs_buffer = kv.second;
+        rhs_buffer_index[rhs_buffer] = i;
+        lhs_buffer_index[lhs_buffer] = i;
+        ++i;
+      }
+    }
+
+    // Step 2: Compute the buffer mask
+    ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size());
+    int num_buffers = rhs_buffer_index.size();
+    std::unordered_map<const VarNode*, BufferMask> rhs_buffer_masks, 
lhs_buffer_masks;
+    // helper function to initialize or update the buffer mask
+    auto update_mask = [&](const VarNode* var,
+                           std::unordered_map<const VarNode*, 
BufferMask>* masks, int i) {
+      if (!masks->count(var)) {
+        (*masks)[var].resize(num_buffers);
+      }
+      (*masks)[var][i] = true;
+    };
+
+    for (const auto& it : extractor_->rhs_buffer_indices_map_) {
+      const Buffer& rhs_buffer = it.first;
+      for (const PrimExpr& rhs_index : it.second) {
+        if (const VarNode* var_node = rhs_index.as<VarNode>()) {
+          update_mask(var_node, &rhs_buffer_masks, 
rhs_buffer_index.at(rhs_buffer));
+        } else {
+          LOG(FATAL) << "ValueError: Buffer index " << rhs_index
+                     << " other than variables in tensor intrinsics is not 
supported.";
+        }
+      }
+
+      auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer);
+      ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end());
+      const Buffer& lhs_buffer = lhs_buffer_it->second;
+      for (const PrimExpr& index : 
extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) {
+        PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
+          if (const VarNode* var = obj.as<VarNode>()) {
+            update_mask(var, &lhs_buffer_masks, 
lhs_buffer_index.at(lhs_buffer));
+          }
+          return true;
+        });
+      }
+    }
+
+    // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure 
LHS and RHS vars
+    // have the same iter type.
+    std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
+    for (const auto& kv : rhs_buffer_masks) {
+      const VarNode* rhs_var = kv.first;
+      const BufferMask& mask = kv.second;
+      mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var));
+    }
+    std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
+    for (const auto& iter : extractor_->rhs_iters_) {
+      rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type);
+    }
+    for (const auto& iter : extractor_->lhs_iters_) {
+      auto& potential_mappings = lhs_feasible_vars_[iter->var];
+      VarSet rhs_candidates = 
mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
+      std::copy_if(
+          rhs_candidates.begin(), rhs_candidates.end(),
+          std::inserter(potential_mappings, potential_mappings.begin()),
+          [&](const Var& var) { return rhs_var_iter_type.at(var.get()) == 
iter->iter_type; });
+    }
+  }
+
+  Array<IndexMap> ProposeAllFuseMapping() {
+    // Now we have calculated potential mapping for each iter var on LHS. For 
iters on LHS mapped to
+    // the same iter on RHS, they will be fused in the original order in LHS 
block iters. We will
+    // generate IndexMap to represent such fusion on LHS. For example, if n, 
h, w on LHS are mapped
+    // to the same iter var on RHS, we will produce index map `lambda n, h, w: 
fuse(n, h, w)`, where
+    // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... ) * vn_extent + vn
+
+    // the parameters of the result index map, each parameter corresponds to an 
LHS iter
+    Array<Var> index_map_src;
+    // the outputs of the result index map
+    Array<PrimExpr> index_map_tgt;
+
+    // Step 1: Collect extents of LHS iters and prepare the initial indices of 
the IndexMap
+    Map<Var, PrimExpr> lhs_iter_extents;
+    for (const auto& iter : extractor_->lhs_iters_) {
+      lhs_iter_extents.Set(iter->var, iter->dom->extent);
+      index_map_src.push_back(iter->var.copy_with_suffix(""));
+    }
+
+    // Step 2: Each iter on RHS has a group of corresponding iters on LHS. 
Initialize the fusion
+    // result for each group of iters on LHS.
+    Map<Var, PrimExpr> fused_lhs_iters;
+    for (const auto& iter : extractor_->rhs_iters_) {
+      fused_lhs_iters.Set(iter->var, 0);
+    }
+
+    // Step 3: Fuse LHS iters mapped to the same RHS iter
+    for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
+      const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
+      const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
+      if (rhs_candidates.empty()) {
+        // put unmapped iters at the beginning
+        index_map_tgt.push_back(index_map_src[i]);
+      } else if (rhs_candidates.size() == 1) {
+        Var rhs_var = *rhs_candidates.begin();
+        PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var);
+        PrimExpr updated_fused_lhs =
+            fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
+        fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+      } else {
+        // non-unique mapping is not supported
+        return {};
+      }
+    }
+    for (const auto& iter : extractor_->rhs_iters_) {
+      index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));  // stays 0 if no LHS iter maps to this RHS iter
+    }
+    // At most one mapping is supported.
+    return {IndexMap(index_map_src, index_map_tgt)};
+  }
+
+ private:
+  // The extractor that has extracted information for auto tensorization from 
the workload and the
+  // tensor intrin.
+  const AutoTensorizeComparator* extractor_;
+  // The arithmetic analyzer.
+  arith::Analyzer* analyzer_;
+  /*! \brief Potential mappings on RHS for each variable on LHS */
+  std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> 
lhs_feasible_vars_;
+};
+
+Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const 
tir::ScheduleState& self,
+                                                               const 
tir::StmtSRef& block_sref,
+                                                               const 
tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, 
desc_func);
+  // Step 2. Check if `desc_block` matches `block`
+  // Ignore the scope of buffers when comparing, since we can do 
cache_read/write
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);  // NOTE(review): scope_sref/scope_block are not used below; GetScopeRoot may still act as a validity check — confirm before removing
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, 
scope_sref);
+  AutoTensorizeComparator extractor(self->mod);  // workload block is the LHS, intrin block is the RHS
+  if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) {
+    return NullOpt;
+  }
+  Array<IndexMap> mappings = 
AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer);
+  if (mappings.empty()) {  // e.g. an LHS iter had more than one RHS candidate
+    return NullOpt;
+  }
+  ObjectPtr<AutoTensorizeMappingInfoNode> ret = 
make_object<AutoTensorizeMappingInfoNode>();
+  ret->mappings = std::move(mappings);
+  ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_);
+  ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_);
+  ret->lhs_iters = std::move(extractor.lhs_iters_);
+  ret->rhs_iters = std::move(extractor.rhs_iters_);
+  return AutoTensorizeMappingInfo(ret);
+}
+
+TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode);
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
+    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
+      return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), 
desc_func);
+    });
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.cc 
b/src/tir/schedule/ir_comparator.cc
index 58c502379a..d8ac40ef05 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -333,12 +333,12 @@ bool TensorizeComparator::CompareBufferAccess(const T* 
lhs, const T* rhs) {
   return true;
 }
 
-template <typename T, typename F>
-bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& 
rhs, F cmp) {
+template <typename T, typename Self, typename F>  // Self: the class declaring the member comparator (may be a subclass)
+bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& 
rhs, F Self::*cmp) {
   if (lhs.same_as(rhs)) return true;
   if (lhs.size() != rhs.size()) return false;
   for (size_t i = 0; i < lhs.size(); ++i) {
-    if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+    if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;  // downcast so derived-class member comparators can be invoked
   }
   return true;
 }
@@ -355,5 +355,125 @@ void TensorizeComparator::EmitError(const std::string& 
error_message) {
   error_messages_.push_back(error_message);
 }
 
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const 
PrimExpr& other) {
+  return false;  // NOTE(review): presumably the fallback for expression kinds without a dedicated visitor; treated as a mismatch
+}
+
+bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& 
other) {
+  return false;  // NOTE(review): presumably the fallback for statement kinds without a dedicated visitor; treated as a mismatch
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& 
other) {
+  const auto* rhs = other.as<BlockNode>();  // NOTE(review): not null-checked; assumes the caller already matched node types
+  // Check block equality.
+  // All iter vars and buffer regions including the order should match.
+  // When checking iter vars, DefEqual is used to remap variables.
+  if (!is_scope_block) {  // nested blocks: compare iters, annotations and alloc buffers strictly
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, 
&AutoTensorizeComparator::CompareIterVar)) {
+      return false;
+    }
+    if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+      return false;
+    }
+    if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+                      &AutoTensorizeComparator::CompareBuffer)) {
+      return false;
+    }
+    for (const IterVar& block_iter : op->iter_vars) {
+      inner_iter_dom_map_.Set(block_iter->var, 
arith::IntSet::FromRange(block_iter->dom));
+    }
+  } else {  // the outermost compared block: only record its iters, do not compare them
+    auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) 
-> bool {
+      for (const auto& iter : op->iter_vars) {  // NOTE(review): this `op` is the lambda parameter, which shadows the enclosing `op`
+        analyzer_.Bind(iter->var, iter->dom);  // iter ranges let CompareBufferAccess simplify and compare indices
+        if (iter->iter_type == IterVarType::kDataPar ||
+            iter->iter_type == IterVarType::kCommReduce) {
+          iters.push_back(iter);
+        } else {
+          return false;  // only data-parallel and reduction iters are supported
+        }
+      }
+      return true;
+    };
+    if (!collect_iter(op, lhs_iters_)) {
+      return false;
+    }
+    if (!collect_iter(rhs, rhs_iters_)) {
+      return false;
+    }
+  }
+  is_scope_block = false;  // any nested block takes the strict branch above
+  return VisitStmt(op->body, rhs->body);
+}
+
+bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& 
rhs) {
+  if (lhs.same_as(rhs)) return true;
+  auto it = rhs_buffer_map_.find(rhs);
+  bool equal;
+  if (it != rhs_buffer_map_.end()) {
+    equal = (*it).second.same_as(lhs);  // RHS buffer already paired: must be the same LHS buffer
+  } else {
+    // Remap both buffer itself and buffer data, skip buffer shape and scope
+    equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+    if (equal) {
+      rhs_buffer_map_[rhs] = lhs;  // record the pairing in both directions
+      lhs_buffer_map_[lhs] = rhs;
+    }
+  }
+  return equal;
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const 
Stmt& other) {
+  const auto* rhs = other.as<BufferStoreNode>();
+  return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);  // stored values must still match structurally
+}
+
+bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const 
PrimExpr& other) {
+  const auto* rhs = other.as<BufferLoadNode>();
+  return CompareBufferAccess(op, rhs);  // indices are recorded/checked for consistency, not compared structurally
+}
+
+template <typename T>
+bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
+  if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+  auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer);
+  if (it_lhs == lhs_buffer_indices_map_.end()) {
+    if (rhs_buffer_indices_map_.find(rhs->buffer) != 
rhs_buffer_indices_map_.end()) {
+      return false;  // RHS buffer accessed before but LHS buffer not: inconsistent pairing
+    }
+    std::vector<PrimExpr> lhs_indices;
+    for (const auto& index : lhs->indices) {
+      lhs_indices.push_back(analyzer_.Simplify(index));
+    }
+    for (const auto& index : rhs->indices) {
+      if (!index.template as<VarNode>()) return false;  // RHS (intrin) indices must be plain variables
+    }
+    lhs_buffer_indices_map_[lhs->buffer] = lhs_indices;  // first access to this buffer pair: remember its indices
+    rhs_buffer_indices_map_[rhs->buffer] = rhs->indices;
+  } else {
+    auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer);
+    if (it_rhs == rhs_buffer_indices_map_.end()) {
+      return false;
+    }
+    auto indices_check = [&](const Array<PrimExpr>& indices,
+                             const Array<PrimExpr>& old_indices) -> bool {
+      if (indices.size() != old_indices.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < indices.size(); ++i) {
+        if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!indices_check(lhs->indices, it_lhs->second)) return false;  // later accesses must use provably equal indices
+    if (!indices_check(rhs->indices, it_rhs->second)) return false;
+  }
+  return true;
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 359677d885..394d828673 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -90,8 +90,8 @@ class TensorizeComparator : public ExprComparator, public 
StmtComparator {
   bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const 
Map<String, ObjectRef>& rhs);
   template <typename T>
   bool CompareBufferAccess(const T* lhs, const T* rhs);
-  template <typename T, typename F>
-  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+  template <typename T, typename Self, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp);
   bool CompareRange(const Range& lhs, const Range& rhs);
   bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
   void EmitError(const std::string& error_message);
@@ -110,6 +110,54 @@ class TensorizeComparator : public ExprComparator, public 
StmtComparator {
   std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_;
 };
 
+/*!
+ * \brief IR comparator for auto tensorization.
+ * This comparator is used to extract the correspondence between the IR of the 
workload (LHS) and the
+ * tensor intrin (RHS). Unlike `TensorizeComparator`, this comparator has 
relaxed requirements
+ * during comparison. It ignores the loop structure (number of loops and their 
extents) and buffer
+ * indices. It only requires the LHS and the RHS to have the same arithmetic 
operations and the same
+ * dtype. With such relaxed requirements, workloads that can only match the 
tensor intrin after
+ * certain transformations (e.g. im2col for conv2d) are allowed for auto 
tensorization.
+ */
+class AutoTensorizeComparator : public TensorizeComparator {
+ public:
+  explicit AutoTensorizeComparator(const IRModule& lhs_mod)
+      : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {}
+
+ private:
+  bool VisitExprDefault_(const Object* op, const PrimExpr& other) override;
+  bool VisitStmtDefault_(const Object* op, const Stmt& other) override;
+
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+
+  bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override;  // pairs buffers by data var and dtype; shape and scope are ignored
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs);  // hides (cannot override) the base-class template of the same name
+
+ public:
+  // Additional information extracted from LHS (the workload) and RHS (the 
tensor intrin).
+
+  /*! \brief Block iters in the LHS stmt. */
+  std::vector<IterVar> lhs_iters_;
+  /*! \brief Block iters in the RHS stmt. */
+  std::vector<IterVar> rhs_iters_;
+  /*! \brief The buffer and its access indices in the LHS stmt. */
+  std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+      lhs_buffer_indices_map_;
+  /*! \brief The buffer and its access indices in the RHS stmt. */
+  std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+      rhs_buffer_indices_map_;
+  /*! \brief Map from LHS buffer to RHS buffer */
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> lhs_buffer_map_;  // NOTE(review): ObjectHash/ObjectEqual here vs ObjectPtrHash/ObjectPtrEqual above
+
+ private:
+  /*! \brief The domain of the inner block iters. */
+  Map<Var, arith::IntSet> inner_iter_dom_map_;  // NOTE(review): only written (in VisitStmt_(BlockNode)) in the code shown
+};
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py 
b/tests/python/unittest/test_tir_schedule_analysis.py
index 19be0b8699..6761203a5a 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -16,14 +16,22 @@
 # under the License.
 # pylint: disable=missing-docstring
 from typing import List
-
+import pytest
 import tvm
+import tvm.testing
+from tvm.tir.function import TensorIntrin
 from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
+from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
 
 
 from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, 
floordiv, floormod, Schedule
 from tvm.tir.analysis import expr_deep_equal
-from tvm.tir.schedule.analysis import suggest_index_map, 
get_tensorize_loop_mapping, TensorizeInfo
+from tvm.tir.schedule.analysis import (
+    get_auto_tensorize_mapping_info,
+    suggest_index_map,
+    get_tensorize_loop_mapping,
+    TensorizeInfo,
+)
 from tvm.script import tir as T
 from tvm.tir.stmt_functor import pre_order_visit
 from tvm.meta_schedule.testing import te_workload
@@ -252,9 +260,43 @@ def test_get_tensorize_loop_mapping_matmul_mma():
         assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
 
 
+def check_index_map(workload, block_name, intrin_name, expected_index_map):  # assert the single proposed mapping equals expected_index_map
+    s = Schedule(workload)
+    block = s.get_block(block_name)
+    desc_func = TensorIntrin.get(intrin_name).desc
+    info = get_auto_tensorize_mapping_info(s, block, desc_func)  # None if no mapping is proposed
+    assert len(info.mappings) == 1  # the proposer currently emits at most one mapping
+    assert 
IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
+
+
+def test_get_auto_tensorize_mapping_info_conv2d():
+    conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64, 
3, 1, 1))  # N=4, H=W=16, CI=CO=64, 3x3 kernel, stride=1, padding=1
+    check_index_map(
+        conv2d,
+        "conv2d_nhwc",
+        WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+        lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw 
* 64 + rc),  # 16x16 output pixels fused per image; rh stride = 3 * 64, rw stride = 64
+    )
+
+
+def test_get_auto_tensorize_mapping_info_conv2d_unit_batch():
+    conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(1, 16, 16, 64, 64, 
3, 1, 1))  # N=1, H=W=16, CI=CO=64, 3x3 kernel, stride=1, padding=1
+    check_index_map(
+        conv2d,
+        "conv2d_nhwc",
+        WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+        # the unit-extent batch iter n is left unfused and emitted first
+        lambda n, h, w, c, rh, rw, rc: (n, h * 16 + w, c, rh * 192 + rw * 64 + 
rc),
+    )
+
+
+@pytest.mark.parametrize("b,m,n,k", [(1, 512, 512, 512), (16, 32, 32, 32)])
+def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
+    matmul = create_prim_func(te_workload.batch_matmul_nkkm_f16(b, m, n, k))  # batch iter b indexes X, Y and Z, so it stays unfused
+    check_index_map(
+        matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, lambda b, m, n, k: 
(b, m, n, k)
+    )
+
+
 if __name__ == "__main__":
-    test_suggest_index_map_simple()
-    test_suggest_index_map_bijective()
-    test_get_tensorize_loop_mapping_dense_vnni()
-    test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
-    test_get_tensorize_loop_mapping_matmul_mma()
+    tvm.testing.main()  # replaces the hand-maintained call list; presumably runs all tests in this file via pytest

Reply via email to