roastduck commented on a change in pull request #5600:
URL: https://github.com/apache/incubator-tvm/pull/5600#discussion_r426361872
##########
File path: src/tir/transforms/lower_warp_memory.cc
##########
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_,
analyzer_).Find(op->body);
- CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
- << "Warp memory must be multiple of the extent of threadIdx.x";
- warp_group_ = alloc_size / (width_ * warp_coeff_);
+
+ // Align the local memory size. The number of elements may not
+ // be a multiple of width_ * warp_coeff_; round it up.
+ int factor = width_ * warp_coeff_;
+ warp_group_ = (alloc_size + (factor - 1)) / factor;
+ alloc_size = warp_group_ * factor;
+
Review comment:
> What’s the extent of threadIdx.x?
It's 5. Threads 0\~4 are shuffled as a group (say, Group 0), and Threads 5\~12
are another group (say, Group 1). If we round 0\~4 up to 0\~7, Threads 5\~7 may
be accessed by both Group 0 and Group 1, which may cause conflicting accesses.
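To make the overlap concrete, here is a small C++ sketch. It assumes (this is not taken from the TVM source) that Group g starts at lane `g * extent` and spans `width` lanes, so with an extent of 5 and the width rounded up to 8, Group 0 covers lanes 0 to 7 and Group 1 covers lanes 5 to 12. The names `LaneRange`, `GroupLanes`, `OverlapCount`, and `ShowOverlap` are hypothetical helpers for illustration only.

```cpp
#include <cstdio>

// Inclusive range of warp lanes [first, last] touched by one shuffle group.
struct LaneRange {
  int first;
  int last;
};

// Hypothetical helper: group `group` starts at lane `group * extent`
// (extent = extent of threadIdx.x) and spans `width` lanes.
constexpr LaneRange GroupLanes(int group, int extent, int width) {
  return LaneRange{group * extent, group * extent + width - 1};
}

// Number of lanes shared by two inclusive ranges (0 if they are disjoint).
constexpr int OverlapCount(LaneRange a, LaneRange b) {
  int lo = a.first > b.first ? a.first : b.first;
  int hi = a.last < b.last ? a.last : b.last;
  return hi >= lo ? hi - lo + 1 : 0;
}

// Prints the Group 0 / Group 1 overlap for extent 5, with width 5 and width 8.
void ShowOverlap() {
  const int extent = 5;
  int unrounded = OverlapCount(GroupLanes(0, extent, 5), GroupLanes(1, extent, 5));
  int rounded = OverlapCount(GroupLanes(0, extent, 8), GroupLanes(1, extent, 8));
  std::printf("width 5: %d shared lanes; width 8: %d shared lanes\n",
              unrounded, rounded);
}
```

With width 5 the two groups are disjoint (0 shared lanes); with width 8 they share lanes 5, 6 and 7 (3 shared lanes), which is the conflict described above.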
> I think the test case in this PR clearly shows that rounding up to the warp
size is needed.
Yes, it is needed, but it may not be sufficient.
> As there is no warp-level allocation of size n, we allocate n/32 elements
in each thread. If n is not a multiple of 32, we need to "over-allocate"
slightly and predicate the access.
Yes, we need to over-allocate the **buffers**, but should we also
over-allocate the **threads**? If not, perhaps we should set the mask of
`__shfl_sync` properly.
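A sketch of the two alternatives in C++, under stated assumptions. `CeilDiv` and `RoundUp` mirror the rounding in the diff above (over-allocating the buffer). `GroupMask` builds a 32-bit lane mask covering only the real threads of one group, assuming groups of `extent` lanes are packed contiguously inside a single warp; such a value would go in the `mask` argument of CUDA's `__shfl_sync(mask, var, srcLane, width)`. CUDA requires `width` to be a power of two no larger than `warpSize`, so a group of 5 threads cannot use a shuffle width of 5 directly; `NextPow2` shows the nearest legal width. All helper names are hypothetical and are not TVM functions.

```cpp
#include <cstdint>

// Ceiling division for positive operands: ceil(n / d).
constexpr int CeilDiv(int n, int d) { return (n + (d - 1)) / d; }

// Round n up to a multiple of factor, mirroring the diff:
//   warp_group_ = (alloc_size + (factor - 1)) / factor;
//   alloc_size  = warp_group_ * factor;
constexpr int RoundUp(int n, int factor) { return CeilDiv(n, factor) * factor; }

// Smallest power of two >= n; __shfl_sync's width must be a power of two <= 32.
constexpr int NextPow2(int n) {
  int p = 1;
  while (p < n) p <<= 1;
  return p;
}

// Hypothetical helper: lane mask covering only the `extent` real threads of
// group `group`, assuming groups occupy lanes group*extent .. group*extent+extent-1
// and fit inside one 32-lane warp (requires (group + 1) * extent <= 32).
constexpr std::uint32_t GroupMask(int group, int extent) {
  return extent >= 32
             ? 0xFFFFFFFFu
             : ((std::uint32_t{1} << extent) - 1u) << (group * extent);
}
```

For example, `RoundUp(13, 5)` is 15 (two padding elements), `NextPow2(5)` is 8, and `GroupMask(1, 5)` is `0x3E0` (lanes 5 to 9 only), which keeps Group 0's padding lanes out of Group 1's shuffle.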