Copilot commented on code in PR #18823:
URL: https://github.com/apache/tvm/pull/18823#discussion_r2854762076


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -508,7 +508,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
         //
         // The former may cause dead lock as there is a divergent
         // branch with a warp sync call inside.
-        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), 
mask_buffer, val, offset);
+        bool cast_offset_to_uint = target_->kind->name == "webgpu";
+        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), 
mask_buffer, val, offset,
+                                     cast_offset_to_uint);

Review Comment:
   For WebGPU, WGSL subgroupShuffle* builtins require the lane/delta argument 
to be `u32`, but this change only casts the shuffle_down offset. Other call 
sites (e.g. the tvm_warp_shuffle used for broadcasting the lane-0 result) can 
still pass an `i32` lane id, which will codegen to `subgroupShuffle(x, <i32>)` 
and fail WGSL type-checking. Consider ensuring all WebGPU warp shuffle 
lane/delta values are cast to `u32` (either by making WebGPU always cast in 
WarpShuffle, or by handling the cast in the WebGPU intrinsic lowering).



##########
src/target/source/intrin_rule_webgpu.cc:
##########
@@ -32,6 +32,29 @@ namespace intrin {
 
 using tir::FLowerIntrinsic;
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+struct WebGPUWarpIntrinsic {
+    const Op operator()(DataType t, const Op& orig_op) const {
+        if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+            return Op::Get("tir.webgpu.subgroup_shuffle");
+        } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+            return Op::Get("tir.webgpu.subgroup_shuffle_up");
+        } else {
+            ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+            return Op::Get("tir.webgpu.subgroup_shuffle_down");
+        }
+    }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+    const CallNode* call = e.as<CallNode>();
+    ICHECK(call != nullptr);
+    ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+    ffi::Array<PrimExpr> webgpu_args{{call->args[1], call->args[2]}};
+    return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), 
webgpu_args);

Review Comment:
   The newly added WebGPU warp shuffle lowering block uses an indentation style 
that differs from the rest of this file (4-space indentation vs the surrounding 
2-space style). Please reformat these lines to match the existing file 
formatting for consistency.
   ```suggestion
     const Op operator()(DataType t, const Op& orig_op) const {
       if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
         return Op::Get("tir.webgpu.subgroup_shuffle");
       } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
         return Op::Get("tir.webgpu.subgroup_shuffle_up");
       } else {
         ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
         return Op::Get("tir.webgpu.subgroup_shuffle_down");
       }
     }
   };
   
   template <typename T>
   static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
     const CallNode* call = e.as<CallNode>();
     ICHECK(call != nullptr);
     ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
     ffi::Array<PrimExpr> webgpu_args{{call->args[1], call->args[2]}};
     return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), 
webgpu_args);
   ```



##########
src/target/source/intrin_rule_webgpu.cc:
##########
@@ -32,6 +32,29 @@ namespace intrin {
 
 using tir::FLowerIntrinsic;
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+struct WebGPUWarpIntrinsic {
+    const Op operator()(DataType t, const Op& orig_op) const {
+        if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+            return Op::Get("tir.webgpu.subgroup_shuffle");
+        } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+            return Op::Get("tir.webgpu.subgroup_shuffle_up");
+        } else {
+            ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+            return Op::Get("tir.webgpu.subgroup_shuffle_down");
+        }
+    }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+    const CallNode* call = e.as<CallNode>();
+    ICHECK(call != nullptr);
+    ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+    ffi::Array<PrimExpr> webgpu_args{{call->args[1], call->args[2]}};

Review Comment:
   `DispatchWebGPUShuffle` forwards `call->args[2]` directly as the lane/delta 
argument, but WGSL subgroupShuffle/subgroupShuffleUp/subgroupShuffleDown expect 
that parameter to be `u32`. Since WebGPU thread indices are typically `u32` 
cast to `i32` in codegen, the delta/lane often ends up as `i32`, producing 
invalid WGSL. Please cast `call->args[2]` to `u32` (or otherwise enforce u32) 
in this dispatcher so all shuffle variants/call sites are handled consistently.
   ```suggestion
       PrimExpr lane = Cast(DataType::UInt(32), call->args[2]);
       ffi::Array<PrimExpr> webgpu_args{{call->args[1], lane}};
   ```



##########
src/target/source/codegen_webgpu.cc:
##########
@@ -118,7 +121,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   }
 }
 
-CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
+CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
+  enable_subgroups_ = 
target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));

Review Comment:
   Emission of the `enable subgroups;` directive is currently controlled only by the `supports_subgroups` 
target attr, but subgroup shuffle ops can be emitted whenever 
`thread_warp_size` is set > 1 (e.g. if a user sets `-thread-warp-size=32` 
directly on the WebGPU target). In that case, the generated WGSL would contain 
`subgroupShuffle*` calls without the required `enable subgroups;` directive. To 
avoid this inconsistent state, consider deriving `enable_subgroups_` from 
`thread_warp_size > 1` as well, or emitting a clear error if subgroup ops are 
encountered while `supports_subgroups` is false.
   ```suggestion
     Bool supports_subgroups = 
target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
     Optional<Integer> thread_warp_size = 
target_->GetAttr<Integer>("thread_warp_size");
     bool warp_uses_subgroups =
         thread_warp_size.defined() && thread_warp_size.value()->value > 1;
     if (warp_uses_subgroups && !supports_subgroups) {
       LOG(FATAL) << "WebGPU target has thread_warp_size=" << 
thread_warp_size.value()->value
                  << " but does not support subgroups. Either enable the 'supports_subgroups' "
                  << "target attribute or set thread_warp_size <= 1.";
     }
     enable_subgroups_ = supports_subgroups || warp_uses_subgroups;
   ```



##########
src/target/target_kind.cc:
##########
@@ -424,8 +424,27 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
     // Tags
     .set_default_keys({"vulkan", "gpu"});
 
+/*!
+ * \brief Update WebGPU target attributes based on subgroup support.
+ * When supports_subgroups is true, set thread_warp_size to 32 so that
+ * TIR lowering uses warp-level shuffle reductions instead of shared memory.
+ */
+TargetJSON UpdateWebGPUAttrs(TargetJSON target) {
+  if (target.count("supports_subgroups")) {
+    bool subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+    if (subgroups) {
+      target.Set("thread_warp_size", int64_t(32));
+    }
+  }
+  return target;
+}
+
 TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
     .add_attr_option<int64_t>("max_num_threads", 256)
+    .add_attr_option<bool>("supports_subgroups", false)
+    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction 
returns false, so no subgroup ops are emitted.
+    .add_attr_option<int64_t>("thread_warp_size", 1)
+    .set_target_parser(UpdateWebGPUAttrs)
     .set_default_keys({"webgpu", "gpu"});

Review Comment:
   This PR introduces new WebGPU target attributes (`supports_subgroups` and 
`thread_warp_size`) and a target parser side effect that mutates 
`thread_warp_size`. There are existing 
Python tests covering target parsing and LowerThreadAllreduce behavior, but 
none validating the new WebGPU defaults/gating. Please add a unit test to 
assert (1) `Target('webgpu')` defaults `thread_warp_size` to 1, and (2) 
`Target({'kind':'webgpu','supports_subgroups': True})` results in 
`thread_warp_size==32` (and ideally that subgroup shuffles are only emitted in 
the latter case).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to