DzAvril commented on pull request #10652:
URL: https://github.com/apache/tvm/pull/10652#issuecomment-1076263251


   There is a bug where double_buffer does not take effect in the tensor core
conv2d template. The test code is the same as in the attachment above. After
lowering, I found that buffers `AS` and `BS` were not double-buffered.
   During lowering, the pass first detects double buffer variables, i.e. those
annotated with attr_key `double_buffer_scope`, and adds them to the
unordered_set `touched_`. In this case, the double buffer variable is named
`T_reshape.shared`.
   ```C++
   // inject_double_buffer.cc:DoubleBufferDetector
   void VisitStmt_(const AttrStmtNode* op) final {
       if (op->attr_key == attr::double_buffer_scope) {
           touched_.insert(op->node.as<VarNode>());
           StmtExprVisitor::VisitStmt_(op);
       } else {
           StmtExprVisitor::VisitStmt_(op);
       }
   }
   ```
   Because the tensor core conv2d template uses tensor intrinsics, the lowered
IR contains a call node `tir.tvm_access_ptr`, one of whose arguments is
`T_reshape.shared`. When `StmtExprVisitor` visits a call node, it also visits
the call's arguments, which leads to this function:
   ```C++
   // inject_double_buffer.cc:DoubleBufferDetector
   void VisitExpr_(const VarNode* op) final {
       if (touched_.count(op)) {
           touched_.erase(op);
       }
   }
   ```
   As the code shows, `T_reshape.shared` is erased from `touched_`, so double
buffering is never applied.
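   To make the failure mode concrete, here is a minimal sketch that reproduces it outside TVM. It assumes a hypothetical, heavily simplified IR (`Var`, `Node`, `MiniDetector` and `MakeTree` are illustrative names, not TVM classes) and mirrors only the two `DoubleBufferDetector` overrides quoted above: the attribute marks the variable, and any `Var` reached through default traversal (here, as a call argument) un-marks it.

   ```C++
   #include <cassert>
   #include <memory>
   #include <string>
   #include <unordered_set>
   #include <vector>

   // Hypothetical, heavily simplified IR for illustration only; NOT TVM's classes.
   struct Var {
       std::string name;
   };

   struct Node {
       enum Kind { kAttr, kCall, kVarRef };
       Kind kind;
       std::string key;  // attr_key for kAttr, callee name for kCall, unused otherwise
       const Var* var;   // annotated var (kAttr), referenced var (kVarRef), else nullptr
       std::vector<std::shared_ptr<Node>> children;  // body (kAttr) or arguments (kCall)
   };

   // Mirrors only the two DoubleBufferDetector overrides quoted above.
   struct MiniDetector {
       std::unordered_set<const Var*> touched;

       void Visit(const Node& n) {
           if (n.kind == Node::kAttr && n.key == "double_buffer_scope") {
               touched.insert(n.var);  // like VisitStmt_(const AttrStmtNode*)
           } else if (n.kind == Node::kVarRef) {
               touched.erase(n.var);   // like VisitExpr_(const VarNode*)
               return;
           }
           // Default traversal: visits attr bodies and, for calls, every argument.
           for (const auto& child : n.children) Visit(*child);
       }
   };

   // Builds: attr [buf] "double_buffer_scope" { tir.tvm_access_ptr(buf) }
   std::shared_ptr<Node> MakeTree(const Var* buf) {
       auto ref = std::make_shared<Node>(Node{Node::kVarRef, "", buf, {}});
       auto call = std::make_shared<Node>(
           Node{Node::kCall, "tir.tvm_access_ptr", nullptr, {ref}});
       return std::make_shared<Node>(
           Node{Node::kAttr, "double_buffer_scope", buf, {call}});
   }
   ```

   Visiting `*MakeTree(&buf)` leaves `touched` empty, while an attribute node with an empty body keeps the variable marked, which isolates the call-argument traversal as the cause.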
   Why erase a double buffer variable that appears as an argument of a call
node? My guess is that the original author expected double buffer variables to
appear only in load or store nodes, so a double buffer variable inside a call
node was not anticipated.
   The solution is simple and specific to the tensor core conv2d template, which
uses tensor intrinsics: when visiting a call node whose op is `tvm_access_ptr`,
skip visiting its arguments.
   ```c++
   void VisitExpr_(const CallNode* op) final {
       // do not visit var in tvm_access_ptr
       if (op->op.same_as(builtin::tvm_access_ptr())) {
           return;
       }
       StmtExprVisitor::VisitExpr_(op);
   }
   ```
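   The effect of the proposed override can be checked on the same kind of hypothetical mini-IR. Again, `Node`, `FixedMiniDetector` and `MakeTree` are illustrative names, not TVM code, and matching on the callee name `"tir.tvm_access_ptr"` stands in for `op->op.same_as(builtin::tvm_access_ptr())`.

   ```C++
   #include <cassert>
   #include <memory>
   #include <string>
   #include <unordered_set>
   #include <vector>

   // Hypothetical, heavily simplified IR for illustration only; NOT TVM's classes.
   struct Var {
       std::string name;
   };

   struct Node {
       enum Kind { kAttr, kCall, kVarRef };
       Kind kind;
       std::string key;  // attr_key for kAttr, callee name for kCall
       const Var* var;   // annotated var (kAttr), referenced var (kVarRef), else nullptr
       std::vector<std::shared_ptr<Node>> children;  // body (kAttr) or arguments (kCall)
   };

   // Detector with the proposed CallNode override added.
   struct FixedMiniDetector {
       std::unordered_set<const Var*> touched;

       void Visit(const Node& n) {
           switch (n.kind) {
               case Node::kAttr:
                   if (n.key == "double_buffer_scope") touched.insert(n.var);
                   break;
               case Node::kCall:
                   // The fix: return early, so no argument of tvm_access_ptr is visited.
                   if (n.key == "tir.tvm_access_ptr") return;
                   break;
               case Node::kVarRef:
                   touched.erase(n.var);  // any other bare use still un-marks the var
                   return;
           }
           for (const auto& child : n.children) Visit(*child);
       }
   };

   // Builds: attr [buf] "double_buffer_scope" { callee(buf) }
   std::shared_ptr<Node> MakeTree(const Var* buf, const std::string& callee) {
       auto ref = std::make_shared<Node>(Node{Node::kVarRef, "", buf, {}});
       auto call = std::make_shared<Node>(Node{Node::kCall, callee, nullptr, {ref}});
       return std::make_shared<Node>(
           Node{Node::kAttr, "double_buffer_scope", buf, {call}});
   }
   ```

   With this model, `MakeTree(&buf, "tir.tvm_access_ptr")` keeps `T_reshape.shared` in `touched`, while a call to any other op still un-marks it, so the detector stays conservative elsewhere. Because the override returns before visiting any argument, it also skips every other argument `tvm_access_ptr` carries, not only the buffer variable.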
   Reference to the original PR: [#405](https://github.com/apache/tvm/pull/405)

