wpan11nv commented on a change in pull request #5600:
URL: https://github.com/apache/incubator-tvm/pull/5600#discussion_r426350654



##########
File path: src/tir/transforms/lower_warp_memory.cc
##########
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
     alloc_size *= op->dtype.lanes();
     std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
     warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
-    CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
-        << "Warp memory must be multiple of the extent of threadIdx.x";
-    warp_group_ = alloc_size / (width_ * warp_coeff_);
+
+    // Align the local memory size. The number of elements may not
+    // be a multiple of width_ * warp_coeff_; round it up.
+    int factor = width_ * warp_coeff_;
+    warp_group_ = (alloc_size + (factor - 1)) / factor;
+    alloc_size = warp_group_ * factor;
+

Review comment:
       What's the extent of threadIdx.x? I think the test case in this PR
clearly shows that rounding up to the warp size is needed. It is the same as
the softmax failures I have seen. As there is no warp-level allocation of size
n, we allocate n/32 elements in each thread. If n is not a multiple of 32, we
need to "over-allocate" slightly and predicate the access.
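
To make the arithmetic concrete, here is a minimal standalone sketch of the round-up and the predicate, assuming a warp size of 32. It is not the TVM pass itself: `WarpLayout`, `RoundUpToWarp`, and `InBounds` are hypothetical names used only for illustration, and `factor` stands in for `width_ * warp_coeff_` from the diff.

```cpp
#include <cassert>

// Hypothetical illustration only; not part of lower_warp_memory.cc.
struct WarpLayout {
  int warp_group;  // number of elements each thread holds locally
  int alloc_size;  // padded total size, always a multiple of factor
};

// Round n logical elements up to a multiple of factor (ceiling division),
// mirroring the three added lines in the diff above.
inline WarpLayout RoundUpToWarp(int n, int factor) {
  assert(n >= 0 && factor > 0);
  int warp_group = (n + (factor - 1)) / factor;
  return WarpLayout{warp_group, warp_group * factor};
}

// Predicate guarding an access: slots in [n, alloc_size) exist only as
// padding, so a read or write to them must be skipped.
inline bool InBounds(int index, int n) {
  return index >= 0 && index < n;
}
```

For example, with n = 33 and a factor of 32, each thread holds 2 elements and the padded size is 64; indices 33 through 63 are padding and fail the `InBounds` check.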




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

