This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 814afe6b78 [Unity] Improve WebGPU codegen for large grid (#14674)
814afe6b78 is described below

commit 814afe6b78ca63bd875713c6653c4485afa18b06
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 19 21:43:33 2023 -0400

    [Unity] Improve WebGPU codegen for large grid (#14674)
    
    [Unity][WEBGPU] Improve WebGPU codegen for large grid
    
    Background: as of now, WebGPU does not allow a grid size bigger than 65535,
    so we have to factorize gridDim.x when it is too big and spread
    it across gridDim.x and gridDim.z.
    
    This factorization, however, is not always possible. This PR passes in an
    extra parameter, packDimX, which records the originally requested dim. We
    over-pad when exact factorization is not possible, and the kernel
    immediately returns if the index is out of bounds.
    
    This PR improves webgpu codegen to handle large launch grid
---
 src/target/source/codegen_webgpu.cc | 65 +++++++++++++++++++------------------
 web/src/webgpu.ts                   | 60 +++++++++++++++++++---------------
 web/tests/python/webgpu_rpc_test.py |  4 +--
 3 files changed, 69 insertions(+), 60 deletions(-)

diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index ab4c2605bd..e933d58e17 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -192,40 +192,38 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const 
PrimFunc& f, bool skip_re
 
   // Store all pod arguments in a single buffer of int32
   // do bitcast to change to other data types
-  if (pod_args.size() != 0) {
-    std::string type_pod_args = name_supply_->FreshName("PODArgs");
-    std::string val_pod_args = name_supply_->FreshName("podArgs");
-
-    this->decl_stream << "\nstruct " << type_pod_args << " {\n";
-
-    for (size_t i = 0; i < pod_args.size(); ++i) {
-      Var v = pod_args[i];
-      ICHECK(!v.dtype().is_handle());
-      std::string vid = AllocVarID(v.get());
-
-      if (v.dtype() == DataType::Int(32)) {
-        this->decl_stream << "  " << vid << ": i32";
-      } else if (v.dtype() == DataType::UInt(32)) {
-        this->decl_stream << "  " << vid << ": u32";
-      } else if (v.dtype() == DataType::Float(32)) {
-        this->decl_stream << "  " << vid << ": f32";
-      } else {
-        LOG(FATAL) << "Do not support pod argument type " << v.dtype();
-      }
-      if (i + 1 != pod_args.size()) {
-        this->decl_stream << ",\n";
-      } else {
-        this->decl_stream << "\n}\n";
-      }
-      // value ref
-      std::ostringstream vref;
-      vref << val_pod_args << "." << vid;
-      var_idmap_[v.get()] = vref.str();
+  // always pass gridDimX in to get around of the 65535 gridDim
+  // restrictions in some platforms
+  std::string type_pod_args = name_supply_->FreshName("PODArgs");
+  std::string val_pod_args = name_supply_->FreshName("podArgs");
+  std::string packGridDimX = name_supply_->FreshName("packGridDimX");
+
+  this->decl_stream << "\nstruct " << type_pod_args << " {\n";
+
+  for (size_t i = 0; i < pod_args.size(); ++i) {
+    Var v = pod_args[i];
+    ICHECK(!v.dtype().is_handle());
+    std::string vid = AllocVarID(v.get());
+
+    if (v.dtype() == DataType::Int(32)) {
+      this->decl_stream << "  " << vid << ": i32";
+    } else if (v.dtype() == DataType::UInt(32)) {
+      this->decl_stream << "  " << vid << ": u32";
+    } else if (v.dtype() == DataType::Float(32)) {
+      this->decl_stream << "  " << vid << ": f32";
+    } else {
+      LOG(FATAL) << "Do not support pod argument type " << v.dtype();
     }
-
-    this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
-                      << "var<uniform> " << val_pod_args << " : " << 
type_pod_args << ";\n\n";
+    this->decl_stream << ",\n";
+    // value ref
+    std::ostringstream vref;
+    vref << val_pod_args << "." << vid;
+    var_idmap_[v.get()] = vref.str();
   }
+  this->decl_stream << "  " << packGridDimX << ": u32\n}\n";
+
+  this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
+                    << "var<uniform> " << val_pod_args << " : " << 
type_pod_args << ";\n\n";
 
   // setup thread tags and param access in launch param tags;
   if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
@@ -249,6 +247,9 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const 
PrimFunc& f, bool skip_re
                << "  @builtin(num_workgroups) gridDim : vec3<u32>,\n"
                << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
                << ") {\n";
+  // skip out of bound grids
+  this->stream << "  if (blockIdx.z * gridDim.x + blockIdx.x > "  // NOLINT(*)
+               << val_pod_args << "." << packGridDimX << ") { return; }\n";
   // the function scope.
   int func_scope = this->BeginScope();
   this->PrintStmt(f->body);
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 953c4eb774..fe128421c7 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -576,6 +576,35 @@ export class WebGPUContext {
 
         assert(args.length == numBufferOrPodArgs + dispatchToDim.length);
 
+        const workDim: Array<number> = [1, 1, 1, 1, 1, 1];
+        for (let i = 0; i < dispatchToDim.length; ++i) {
+          workDim[dispatchToDim[i]] = args[numBufferOrPodArgs + i];
+        }
+
+        // get around 65535 restriction of blockIdx.x
+        if (workDim[2] != 1) {
+          throw Error("WebGPU: blockIdx.z is reserved for internal use");
+        }
+        const packDimX = workDim[0];
+        // spread things out into blockIdx.z
+        if (workDim[0] >= (1 << 16)) {
+          let wl_x = workDim[0];
+          let wl_z = workDim[2];
+
+          while (wl_x >= (1 << 16)) {
+            if (wl_x % 2 == 0) {
+              wl_x = wl_x / 2;
+            } else {
+              // pad up
+              wl_x = (wl_x + 1) / 2;
+            }
+            wl_z *= 2;
+          }
+          workDim[0] = wl_x;
+          workDim[2] = wl_z;
+          assert(wl_x * wl_z >= packDimX);
+        }
+
         for (let i = 0; i < bufferArgIndices.length; ++i) {
           bindGroupEntries.push({
             binding: i,
@@ -588,8 +617,8 @@ export class WebGPUContext {
         // push pod buffer
         if (podArgIndices.length != 0) {
           const sizeOfI32 = 4;
-          const podArgBuffer = this.getPodArgsBuffer(podArgIndices.length * 
sizeOfI32);
-          const i32View = new Int32Array(podArgIndices.length);
+          const podArgBuffer = this.getPodArgsBuffer((podArgIndices.length + 
1) * sizeOfI32);
+          const i32View = new Int32Array(podArgIndices.length + 1);
           const u32View = new Uint32Array(i32View.buffer);
           const f32View = new Float32Array(i32View.buffer);
 
@@ -606,6 +635,8 @@ export class WebGPUContext {
               throw Error("Unknown pod dtype " + dtype);
             }
           }
+          // always pass in dim z launching grid size in
+          u32View[podArgIndices.length] = packDimX;
           this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer);
 
           bindGroupEntries.push({
@@ -621,31 +652,8 @@ export class WebGPUContext {
           layout: bindGroupLayout,
           entries: bindGroupEntries
         }));
-        const wl: Array<number> = [1, 1, 1, 1, 1, 1];
-        for (let i = 0; i < dispatchToDim.length; ++i) {
-          wl[dispatchToDim[i]] = args[numBufferOrPodArgs + i];
-        }
 
-        // get around 65535 restriction of blockIdx.x
-        if (wl[2] != 1) {
-          throw Error("WebGPU: blockIdx.z is reserved for internal use");
-        }
-        // spread thinsg out into blockIdx.z
-        if (wl[0] >= (1 << 16)) {
-          let wl_x = wl[0];
-          let wl_z = wl[2];
-
-          while (wl_x >= (1 << 16)) {
-            if (wl_x % 2 != 0) {
-              throw Error("WebGPU: cannot factorize big gridDim.x=" + 
wl[0].toString());
-            }
-            wl_x /= 2;
-            wl_z *= 2;
-          }
-          wl[0] = wl_x;
-          wl[2] = wl_z;
-        }
-        compute.dispatchWorkgroups(wl[0], wl[1], wl[2])
+        compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2])
         compute.end()
         const command = commandEncoder.finish();
         this.device.queue.submit([command]);
diff --git a/web/tests/python/webgpu_rpc_test.py 
b/web/tests/python/webgpu_rpc_test.py
index 31dd8fc043..9c26c2397a 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -44,7 +44,7 @@ def test_rpc():
     mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]))
     sch = tvm.tir.Schedule(mod)
     (i,) = sch.get_loops(block=sch.get_block("B"))
-    i0, i1 = sch.split(i, [None, 128])
+    i0, i1 = sch.split(i, [None, 32])
     sch.bind(i0, "blockIdx.x")
     sch.bind(i1, "threadIdx.x")
 
@@ -76,7 +76,7 @@ def test_rpc():
         np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), 
atol=1e-5, rtol=1e-5)
         print("Test pass..")
 
-    check(remote, 2049)
+    check(remote, 71821 * 32)
 
 
 test_rpc()

Reply via email to