roastduck commented on a change in pull request #5498:
URL: https://github.com/apache/incubator-tvm/pull/5498#discussion_r421887613



##########
File path: src/tir/transforms/lower_thread_allreduce.cc
##########
@@ -330,6 +482,59 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
                    {StringImmNode::make(sync)},
                    CallNode::Intrinsic));
   }
+
+  // Emit a call to warp-shuffle intrinsic `name` on `val`, masked by mask_var[0].
+  PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val,
+                       int delta_or_lane) {
+    PrimExpr pred = const_true(1);  // predicate for the mask load (always true)
+    PrimExpr index(0);  // the mask is read from element 0 of mask_var
+    PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred);
+    PrimExpr width = IntImm(DataType::Int(32), warp_size_);  // width = warp_size_
+    Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane),
+                         width, width};  // NOTE(review): 2nd `width` presumably the warp size; confirm
+    return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic);  // result has val's dtype
+  }
+
+  // Check if this is a reduction on threadIdx.x and its extent matches
+  // the warp size.
+  //
+  // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
+  bool is_warp_reduction(const std::vector<DataType>& types) const {
+    // Only cuda target supports warp reductions.
+    if (target_->target_name != "cuda") return false;
+
+    // Supported types:
+    // {u}int, {u}long, {u}long long, float, double, half/half2
+    if (std::any_of(types.begin(), types.end(), [](DataType ty) {
+          if (ty.is_float16()) return ty.lanes() > 2;
+          if (ty.is_vector()) return true;
+          return ty.bytes() < 4 || ty.bytes() > 8;
+        })) {
+      return false;
+    }
+    if (thread_extents_.empty()) {
+      return false;
+    }
+
+    const AttrStmtNode* op = thread_extents_.back();

Review comment:
       Is the `AttrStmtNode` for the reduction axis guaranteed to be the innermost one (i.e., the last entry in `thread_extents_`)? If not, we can't use `back()` here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to