masahi commented on code in PR #11050:
URL: https://github.com/apache/tvm/pull/11050#discussion_r853504746


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2028,5 +2034,107 @@ bool NeedsRFactorOrCrossThreadReduction(const 
tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& 
block_sref,
+                                                const tir::PrimFunc& 
desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, 
scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = 
loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = 
GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = 
GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    const tir::ForNode* desc_loop = desc_loops[i_desc];
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+    if (!int_desc_extent) continue;
+
+    for (int i_block = next_block_ind; i_block >= 0; --i_block) {
+      const tir::ForNode* block_loop = block_loops[i_block];
+      const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
+
+      if (!int_block_extent) continue;
+      if (int_block_extent->value % int_desc_extent->value != 0) continue;
+      if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
+
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      next_block_ind = i_block - 1;
+      break;
+    }
+  }

Review Comment:
   Thanks @spectrometerHBH! I now understand the original code and was able to 
integrate the original logic to support loop permutations. Please have a look 
at the current diff. Also cc @vinx13 @Hzfengsy @MasterJH5574
   
   The key difference between the original code and the code I submitted 
yesterday is that my code looked only at the loop nest (`ForNode`) to 
determine the mapping, while @spectrometerHBH's mapping logic is based on the 
block's `iter_var`s and `iter_value`s (and is therefore invariant to the order 
of the loop nest). 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to