masahi commented on code in PR #11740:
URL: https://github.com/apache/tvm/pull/11740#discussion_r899790961


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2240,5 +2265,207 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
     });
 
+/******** Auto Tensorization ********/
+
+/*! \brief IndexMap proposer for layout transformation in auto tensorization. */
+class AutoTensorizeMappingProposer {
+ public:
+  static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator* extractor,
+                                         arith::Analyzer* analyzer) {
+    AutoTensorizeMappingProposer proposer(extractor, analyzer);
+    proposer.CollectFeasibleSet();
+    return proposer.ProposeAllFuseMapping();
+  }
+
+ private:
+  explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* extractor,
+                                        arith::Analyzer* analyzer)
+      : extractor_(extractor), analyzer_(analyzer) {}
+
+  using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+  void CollectFeasibleSet() {
+    // Collect the set of potential iter var mappings between the workload and the tensor intrin.
+    // We analyze the appearance of each variable in the buffer indices of each buffer on the LHS
+    // and the RHS. The appearance of a variable in the buffer indices is encoded as a bit-mask
+    // (BufferMask). Variables on the LHS and the RHS with the same bit-mask are potential
+    // mappings.
+    //
+    // For example, consider the conv2d case. We will try to match the workload
+    // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, rc] * W[rh, rw, rc, c]
+    // against a matmul tensor intrin
+    // C[m, n] = sum_{k} A[m, k] * B[k, n]
+    // First we extract the correspondence of the buffers: conv2d <=> C, A <=> X, B <=> W.
+    // Then for each variable, we extract the buffers where it is used for indexing.
+    // Take the variable m on the RHS as an example. m is used to index buffers A and C. On the
+    // LHS, we will find the variables used to index only the exact corresponding buffers conv2d
+    // and X (such a variable is not allowed to index any other buffer). In this case, n, h, w
+    // are used to index both conv2d and X, and no other buffers. Therefore, {n, h, w} <=> m is
+    // a potential mapping.

Review Comment:
   How about also checking the iterator types (spatial vs. reduction)?
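   To illustrate: the grouping described in the quoted comment, extended with the iterator-kind check suggested here, can be sketched in a self-contained toy (this is a hypothetical mock, not the actual `AutoTensorizeMappingProposer` implementation; buffer ids become bit positions, variables become strings, and candidates must agree on both the buffer bit-mask and the iterator kind):

   ```cpp
   #include <cassert>
   #include <cstdint>
   #include <map>
   #include <set>
   #include <string>
   #include <utility>

   // Hypothetical sketch: each variable is keyed by (buffer bit-mask, iter kind).
   // LHS and RHS variables with equal keys are potential mappings.
   enum class IterKind { kSpatial, kReduction };
   using Key = std::pair<uint64_t, IterKind>;  // (buffer mask, iterator kind)
   using VarInfo = std::map<std::string, Key>;

   std::map<std::string, std::set<std::string>> ProposeCandidates(
       const VarInfo& lhs, const VarInfo& rhs) {
     std::map<std::string, std::set<std::string>> result;
     for (const auto& [rhs_var, rhs_key] : rhs)
       for (const auto& [lhs_var, lhs_key] : lhs)
         if (lhs_key == rhs_key) result[rhs_var].insert(lhs_var);
     return result;
   }

   int main() {
     // Buffer bits, following the comment's correspondence:
     // conv2d <=> C -> 0b001, X <=> A -> 0b010, W <=> B -> 0b100.
     VarInfo lhs = {
         {"n", {0b011, IterKind::kSpatial}},    {"h", {0b011, IterKind::kSpatial}},
         {"w", {0b011, IterKind::kSpatial}},    {"c", {0b101, IterKind::kSpatial}},
         {"rh", {0b110, IterKind::kReduction}}, {"rw", {0b110, IterKind::kReduction}},
         {"rc", {0b110, IterKind::kReduction}},
     };
     VarInfo rhs = {
         {"m", {0b011, IterKind::kSpatial}},    // indexes C and A
         {"n", {0b101, IterKind::kSpatial}},    // indexes C and B
         {"k", {0b110, IterKind::kReduction}},  // indexes A and B
     };
     auto cand = ProposeCandidates(lhs, rhs);
     assert((cand["m"] == std::set<std::string>{"h", "n", "w"}));    // {n,h,w} <=> m
     assert((cand["n"] == std::set<std::string>{"c"}));              // {c} <=> n
     assert((cand["k"] == std::set<std::string>{"rc", "rh", "rw"})); // reduction axes <=> k
     return 0;
   }
   ```

   With the iterator kind folded into the key, a spatial workload iter can never be proposed as a mapping for a reduction intrin iter even if the bit-masks coincide.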



##########
src/tir/schedule/analysis.h:
##########
@@ -707,6 +707,55 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                 const tir::StmtSRef& block_sref,
                                                 const tir::PrimFunc& desc_func);
 
+/*! \brief Necessary information used to perform transformations for tensorization */
+class AutoTensorizeMappingInfoNode : public Object {
+ public:
+  /*! \brief Possible mappings to apply to block iters */
+  Array<IndexMap> mappings;
+
+  /* Additional information from AutoTensorizeComparator */
+
+  /*! \brief Mapping from LHS buffer to RHS buffer */
+  Map<Buffer, Buffer> lhs_buffer_map;
+  /*! \brief Buffer indices on RHS */
+  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
+  /*! \brief Block iters on LHS */
+  Array<IterVar> lhs_iters;
+  /*! \brief Block iters on RHS */
+  Array<IterVar> rhs_iters;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("mappings", &mappings);
+    v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
+    v->Visit("lhs_iters", &lhs_iters);
+    v->Visit("rhs_iters", &rhs_iters);

Review Comment:
   `lhs_buffer_map` not visited. Remove it if not needed?
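
   For illustration, a minimal standalone mock of the visitor pattern (not TVM's real `AttrVisitor`) shows why a field left out of `VisitAttrs` is invisible to reflection-based consumers such as serialization:

   ```cpp
   #include <cassert>
   #include <string>
   #include <vector>

   // Toy stand-in for TVM's AttrVisitor: records the names it is shown.
   struct AttrVisitor {
     std::vector<std::string> seen;
     template <typename T>
     void Visit(const char* name, T* /*field*/) { seen.emplace_back(name); }
   };

   // Toy stand-in for AutoTensorizeMappingInfoNode (field types simplified).
   struct MappingInfo {
     int mappings = 0;
     int lhs_buffer_map = 0;  // declared but, as in the diff, never visited
     int rhs_buffer_indices = 0;

     void VisitAttrs(AttrVisitor* v) {
       v->Visit("mappings", &mappings);
       v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
       // v->Visit("lhs_buffer_map", &lhs_buffer_map);  // the missing call
     }
   };

   int main() {
     AttrVisitor v;
     MappingInfo info;
     info.VisitAttrs(&v);
     // Reflection sees only the visited fields; lhs_buffer_map is dropped.
     assert((v.seen == std::vector<std::string>{"mappings", "rhs_buffer_indices"}));
     return 0;
   }
   ```

   So the options are exactly the two the review names: visit the field too, or remove it from the node if nothing needs it.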


