akaashrp commented on code in PR #18823:
URL: https://github.com/apache/tvm/pull/18823#discussion_r2916185667
##########
src/s_tir/transform/lower_thread_allreduce.cc:
##########
Review Comment:
I think we should add a check of the following form here since WebGPU does
not support f64, i64, or u64:
```
if (target_->kind->name == "webgpu" &&
    std::any_of(types.begin(), types.end(), [](DataType ty) {
      if (ty.is_float() && ty.bits() > 32) return true;  // f64
      if ((ty.is_int() || ty.is_uint()) && ty.bits() > 32) return true;  // i64/u64
      return false;
    })) {
  return false;
}
```
##########
src/target/target_kind.cc:
##########
@@ -427,8 +427,28 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
// Tags
.set_default_keys({"vulkan", "gpu"});
+/*!
+ * \brief Update WebGPU target attributes based on subgroup support.
+ * When supports_subgroups is true, set thread_warp_size to 32 so that
+ * TIR lowering uses warp-level shuffle reductions instead of shared memory.
+ */
+ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) {
+  if (target.count("supports_subgroups")) {
+    bool subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+    if (subgroups) {
Review Comment:
Could you add a comment stating the following:
1. Runtime routing on the WebLLM side guarantees subgroup size == 32
2. Runtime routing on the WebLLM side guarantees maxComputeInvocationsPerWorkgroup >= 1024
3. This is intentionally constrained for the subgroup-enabled WASM variant
##########
src/target/source/codegen_webgpu.cc:
##########
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
}
}
-CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
+CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
+  enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
Review Comment:
Let's add the following check here:
```
Bool supports_subgroups =
    target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
int64_t thread_warp_size =
    target_->GetAttr<Integer>("thread_warp_size", 1).value()->value;
if (thread_warp_size > 1 && !supports_subgroups) {
  LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size
             << " but supports_subgroups is false.";
}
enable_subgroups_ = supports_subgroups;
```
##########
src/s_tir/transform/lower_thread_allreduce.cc:
##########
@@ -719,11 +723,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                        int contiguous_reduce_extent) {
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
-        (target_->kind->name != "metal")) {
+        (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
       return false;
     }
Review Comment:
Add a check here of the following form:
```
if (target_->kind->name == "webgpu" && !supports_subgroups_) {
return false;
}
```
This is to avoid scenarios where a target such as `{"kind":"webgpu","thread_warp_size":32,"supports_subgroups":false}` would still emit subgroup ops even though the generated WGSL would not contain `enable subgroups;`.
##########
src/s_tir/transform/lower_thread_allreduce.cc:
##########
@@ -700,6 +700,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // Emit warp shuffle calls.
   PrimExpr WarpShuffle(const Op& op, ffi::Optional<Buffer> mask_buffer, PrimExpr val,
                        PrimExpr delta_or_lane) {
+    // WebGPU's WGSL requires u32 for subgroupShuffle lane/delta arguments.
Review Comment:
I believe you can move this to `DispatchWebGPUShuffle` by casting `call->args[2]` to `UInt(32)`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]