gemini-code-assist[bot] commented on code in PR #18823:
URL: https://github.com/apache/tvm/pull/18823#discussion_r2854736683


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -717,11 +722,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                        int contiguous_reduce_extent) {
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
-        (target_->kind->name != "metal")) {
+        (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
       return false;
     }
 
-    need_warp_shuffle_mask_ = target_->kind->name != "metal";
+    need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   To improve maintainability, consider using `std::unordered_set` for checking 
the target kind. This makes it easier to add or remove supported targets in the 
future.
   
   ```suggestion
       const std::unordered_set<std::string> supported_targets = {"cuda", "rocm", "metal", "webgpu"};
       if (!supported_targets.count(target_->kind->name)) {
         return false;
       }
   
       const std::unordered_set<std::string> no_mask_targets = {"metal", "webgpu"};
       need_warp_shuffle_mask_ = !no_mask_targets.count(target_->kind->name);
   ```
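   
   For illustration only, the set-based check can be sketched outside of TVM as 
   two free functions. The helper names `SupportsWarpReduction` and 
   `NeedsWarpShuffleMask` are hypothetical, and the sketch assumes the target 
   kind is available as a `std::string`; the real `target_->kind->name` may need 
   an explicit conversion before it can be looked up in such a set.
   
   ```cpp
   #include <string>
   #include <unordered_set>
   
   // Hypothetical stand-alone helpers; these names are not part of TVM.
   
   // Target kinds for which the diff above allows warp-level reduction.
   inline bool SupportsWarpReduction(const std::string& kind) {
     // static: the set is built once instead of on every call.
     static const std::unordered_set<std::string> kSupported = {"cuda", "rocm", "metal",
                                                                "webgpu"};
     return kSupported.count(kind) != 0;
   }
   
   // Target kinds for which the diff sets need_warp_shuffle_mask_ to false.
   // Only meaningful for kinds that already passed SupportsWarpReduction.
   inline bool NeedsWarpShuffleMask(const std::string& kind) {
     static const std::unordered_set<std::string> kNoMask = {"metal", "webgpu"};
     return kNoMask.count(kind) == 0;
   }
   ```
   
   Declaring the sets `static const` avoids rebuilding them on each call, which 
   the suggestion as written would do every time `IsWarpReduction` runs.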



##########
src/target/target_kind.cc:
##########
@@ -424,8 +424,27 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
     // Tags
     .set_default_keys({"vulkan", "gpu"});
 
+/*!
+ * \brief Update WebGPU target attributes based on subgroup support.
+ * When supports_subgroups is true, set thread_warp_size to 32 so that
+ * TIR lowering uses warp-level shuffle reductions instead of shared memory.
+ */
+TargetJSON UpdateWebGPUAttrs(TargetJSON target) {
+  if (target.count("supports_subgroups")) {
+    bool subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+    if (subgroups) {
+      target.Set("thread_warp_size", int64_t(32));
+    }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This implementation sets `thread_warp_size` to 32 whenever 
`supports_subgroups` is true, which overwrites any user-provided value for 
`thread_warp_size`. It would be more flexible to set the value only when the 
user has not already specified one, so that it can be overridden on devices 
with a different subgroup size.
   
   ```suggestion
       if (subgroups && !target.count("thread_warp_size")) {
         target.Set("thread_warp_size", int64_t(32));
       }
   ```
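   
   The intended behavior can be checked in isolation with a minimal sketch. 
   `AttrMap` and `UpdateAttrsSketch` are hypothetical names, and a plain 
   `std::unordered_map` stands in for `TargetJSON` (with booleans stored as 
   0/1) only so that the example compiles without TVM.
   
   ```cpp
   #include <cstdint>
   #include <string>
   #include <unordered_map>
   
   // Stand-in for TVM's TargetJSON (an assumption for this sketch):
   // attribute name -> integer value, with booleans stored as 0/1.
   using AttrMap = std::unordered_map<std::string, int64_t>;
   
   // Hypothetical helper mirroring the suggested logic: default
   // thread_warp_size to 32 only when subgroups are enabled and the
   // user has not already provided a value.
   inline AttrMap UpdateAttrsSketch(AttrMap target) {
     auto it = target.find("supports_subgroups");
     const bool subgroups = it != target.end() && it->second != 0;
     if (subgroups && target.count("thread_warp_size") == 0) {
       target["thread_warp_size"] = 32;
     }
     return target;
   }
   ```
   
   With this logic, a target that sets `supports_subgroups` and 
   `thread_warp_size=64` keeps 64, while one that sets only 
   `supports_subgroups` receives the default of 32.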



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

