roastduck commented on a change in pull request #5498:
URL: https://github.com/apache/incubator-tvm/pull/5498#discussion_r421898436



##########
File path: src/tir/transforms/lower_thread_allreduce.cc
##########
@@ -181,52 +189,196 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     std::sort(vpar.begin(), vpar.end());
     // the size of each index.
     int reduce_extent, group_extent;
-    int threadx_extent = 1;
     PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
     PrimExpr group_index = FlattenThread(vpar, &group_extent);
-    if (reduce_extent == 1) {
-      // special case, no reduction is needed.
-      std::vector<Stmt> stores(size);
+    std::vector<Stmt> seq;
+    std::vector<Var> shared_bufs(size);
+    std::vector<Stmt> local_vars;
+    //
+    // This is an optimization. For small reduction sizes, it may be beneficial
+    // for a single warp to performance the entire reduction. No trips to shared
+    // memory and no cross warp synchronizations are required.
+    // The following code emits the reduction as follows:
+    //
+    // Allocate reduction vars v[i], i = 0..size-1
+    //
+    // for offset from 16 to 1 by 2
+    //
+    //   a    <- load(v[i])
+    //   b    <- shuffle_down(load(v[i], offset))
+    //   v[i] <- reduction(a, b)
+    //
+    // broadcast results from lane 0 to all other lanes and store
+    // the final reduction result to the proper location.
+    //
+    if (is_warp_reduction(types)) {
+      // TODO(tvm-team) sub-warp reduction support.
+      CHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
+      //
+      // This is the index to the reduction variable, one reduction
+      // variable per warp. Local scope seems easier to reason without
+      // relying on a pattern match pass to fix it later.
+      PrimExpr index(0);
+
+      for (size_t idx = 0; idx < size; ++idx) {
+        shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
+        PrimExpr pred = const_true(types[idx].lanes());
+        seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred));
+
+        // Uses a local variable to store the shuffled data.
+        // Later on, this allocation will be properly attached to this statement.
+        Var var("t" + std::to_string(idx), types[idx]);
+        Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred,
+                                    EvaluateNode::make(0));
+        local_vars.push_back(s);

Review comment:
       Are we using `var` to do the shuffles, then storing the result to
`shared_bufs[idx]`, and finally storing it to the target buffer? Can we perform
the shuffles directly on `shared_bufs[idx]` instead of using two sets of variables?
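
For reference, here is a minimal host-side sketch of the shuffle-down reduction that the quoted comment describes. It is a plain C++ model, not the code this pass emits: `kWarpSize`, `ShuffleDown`, and `WarpSum` are hypothetical names, the warp width of 32 and the integer sum are assumptions, and the offset is assumed to halve each step (16, 8, 4, 2, 1). The array `v` stands in for the per-lane reduction variable and `shuffled` for the shuffled operand `b`.

```cpp
#include <array>
#include <cstddef>

// Assumed warp width; the pass itself reads this from warp_size_.
constexpr std::size_t kWarpSize = 32;

// Hypothetical host-side model of shuffle_down: lane i reads the value held
// by lane i + offset. A lane whose source is out of range keeps its own value.
std::array<int, kWarpSize> ShuffleDown(const std::array<int, kWarpSize>& v,
                                       std::size_t offset) {
  std::array<int, kWarpSize> out{};
  for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
    const std::size_t src = lane + offset;
    out[lane] = src < kWarpSize ? v[src] : v[lane];
  }
  return out;
}

// Tree reduction over one warp: after the last step, lane 0 holds the sum of
// all lanes. The caller would then broadcast v[0] to the other lanes.
int WarpSum(std::array<int, kWarpSize> v) {
  for (std::size_t offset = kWarpSize / 2; offset > 0; offset /= 2) {
    // b <- shuffle_down(v, offset)
    const std::array<int, kWarpSize> shuffled = ShuffleDown(v, offset);
    // v <- reduction(a, b), with addition as the assumed reducer
    for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
      v[lane] += shuffled[lane];
    }
  }
  return v[0];
}
```

In this model the shuffled value only lives for one step before it is combined back into `v`, which is the pattern the question above is about.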




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

