This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a6831ba9c4 [Unity] Enable pod args in WebGPU (#14560)
a6831ba9c4 is described below

commit a6831ba9c4cde5a14879cc2b3911e81977beecd4
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 10 14:13:56 2023 -0400

    [Unity] Enable pod args in WebGPU (#14560)
    
    This PR adds POD argument support in webgpu.
---
 python/tvm/exec/rpc_proxy.py        |   2 +-
 python/tvm/rpc/proxy.py             |   3 +-
 src/target/source/codegen_webgpu.cc |  46 +++++++++++--
 web/src/runtime.ts                  |   3 +-
 web/src/webgpu.ts                   | 130 +++++++++++++++++++++++++++++++-----
 web/tests/python/webgpu_rpc_test.py |  28 ++++----
 6 files changed, 170 insertions(+), 42 deletions(-)

diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index 8cf1e4010b..eaa406906a 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -35,7 +35,7 @@ def find_example_resource():
         ("/", os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")),
         ("/", index_page),
     ]
-    allow_format = ("json", "bin", "js", "wasm", "html")
+    allow_format = ("json", "bin", "js", "wasm", "html", "css")
 
     # recursively apend things in www, up to two levels
     resource_bases = [
diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py
index 59af53d4e1..a6709a3421 100644
--- a/python/tvm/rpc/proxy.py
+++ b/python/tvm/rpc/proxy.py
@@ -204,7 +204,8 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
 
 
 MIME_MAP = {
-    "js": "application/javascript",
+    "js": "text/javascript",
+    "css": "text/css",
     "wasm": "application/wasm",
     "json": "application/json",
 }
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 32b3206394..188aa12f3a 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -146,8 +146,9 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
   // setup buffer argumemts
   for (Var arg : f->params) {
     DataType t = arg.dtype();
+    func_info.arg_types.push_back(t);
+
     if (t.is_handle()) {
-      func_info.arg_types.push_back(t);
       auto* ptr = arg->type_annotation.as<PointerTypeNode>();
       ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a "
                      "PointerType, "
@@ -184,6 +185,43 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
     }
   }
 
+  // Store all pod arguments in a single buffer of int32
+  // do bitcast to change to other data types
+  if (pod_args.size() != 0) {
+    std::string type_pod_args = name_supply_->FreshName("PODArgs");
+    std::string val_pod_args = name_supply_->FreshName("podArgs");
+
+    this->decl_stream << "\nstruct " << type_pod_args << " {\n";
+
+    for (size_t i = 0; i < pod_args.size(); ++i) {
+      Var v = pod_args[i];
+      ICHECK(!v.dtype().is_handle());
+      std::string vid = AllocVarID(v.get());
+
+      if (v.dtype() == DataType::Int(32)) {
+        this->decl_stream << "  " << vid << ": i32";
+      } else if (v.dtype() == DataType::UInt(32)) {
+        this->decl_stream << "  " << vid << ": u32";
+      } else if (v.dtype() == DataType::Float(32)) {
+        this->decl_stream << "  " << vid << ": f32";
+      } else {
+        LOG(FATAL) << "Do not support pod argument type " << v.dtype();
+      }
+      if (i + 1 != pod_args.size()) {
+        this->decl_stream << ",\n";
+      } else {
+        this->decl_stream << "\n}\n";
+      }
+      // value ref
+      std::ostringstream vref;
+      vref << val_pod_args << "." << vid;
+      var_idmap_[v.get()] = vref.str();
+    }
+
+    this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
+                      << "var<uniform> " << val_pod_args << " : " << type_pod_args << ";\n\n";
+  }
+
   // setup thread tags and param access in launch param tags;
   if (auto opt = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
     auto thread_axis = opt.value();
@@ -194,12 +232,6 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
   os_param_access << "]";
   func_info.launch_param_tags.push_back(os_param_access.str());
 
-  if (pod_args.size() != 0) {
-    // setup POD arguments
-    // TODO(tvm-team): store as a uniform, readonly buffer.
-    LOG(FATAL) << "Do not support pod arguments for now";
-  }
-
   ICHECK(!info.has_block_index_z)
       << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
   // anotate workgroup
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index f3a6029bbe..427bb2d8f8 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -67,6 +67,7 @@ class FFILibrary implements Disposable {
     while (this.recycledCallStacks.length != 0) {
       (this.recycledCallStacks.pop() as Disposable).dispose();
     }
+    this.webGPUContext?.dispose();
   }
 
   sizeofPtr(): number {
@@ -1031,8 +1032,6 @@ export class Instance implements Disposable {
   }
 
   dispose(): void {
-    // dispose canvas resource
-    this.lib.webGPUContext?.disposeCanvas();
     // order matters
     // ctx release goes back into lib.
     this.ctx.dispose();
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index ac39595c76..c68c42520f 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -285,7 +285,10 @@ export class WebGPUContext {
   // internal data
   private bufferTable: Array<GPUBuffer | undefined> = [undefined];
   private bufferTableFreeId: Array<number> = [];
+  private podArgStagingBuffers: Array<GPUBuffer> = [];
   private canvasRenderManager?: CanvaRenderManager = undefined;
+  // number of pod arg staging buffers
+  private maxNumPodArgsStagingBuffers: number = 2;
   // flags for debugging
   // stats of the runtime.
   // peak allocation
@@ -307,18 +310,25 @@ export class WebGPUContext {
   }
 
   /**
-   * Wait for all pending GPU tasks to complete
+   * Dispose context.
    */
-  async sync(): Promise<void> {
-    await this.device.queue.onSubmittedWorkDone();
+  dispose() {
+    this.canvasRenderManager?.dispose();
+    this.bufferTableFreeId = [];
+    while (this.bufferTable.length != 0) {
+      this.bufferTable.pop()?.destroy();
+    }
+    while (this.podArgStagingBuffers.length != 0) {
+      this.podArgStagingBuffers.pop()?.destroy();
+    }
+    this.device.destroy();
   }
 
   /**
-   * Dispose the binded canvas.
+   * Wait for all pending GPU tasks to complete
    */
-  disposeCanvas() {
-    this.canvasRenderManager?.dispose();
-    this.canvasRenderManager = undefined;
+  async sync(): Promise<void> {
+    await this.device.queue.onSubmittedWorkDone();
   }
 
   /**
@@ -391,7 +401,7 @@ export class WebGPUContext {
    * @returns The shader
    */
   createShader(finfo: FunctionInfo, code: string) : Function {
-    return this.createShadeInternl(finfo, code, false) as Function;
+    return this.createShadeInternal(finfo, code, false) as Function;
   }
 
   /**
@@ -403,7 +413,41 @@ export class WebGPUContext {
    * @returns The shader
    */
   async createShaderAsync(finfo: FunctionInfo, code: string) : Promise<Function> {
-    return await (this.createShadeInternl(finfo, code, true) as Promise<Function>);
+    return await (this.createShadeInternal(finfo, code, true) as Promise<Function>);
+  }
+
+  /**
+   * Get the pod arg staging buffer
+   * \param nbytes The minimum size.
+   * \return The allocated buffer
+   */
+  private getPodArgsBuffer(nbytes: number) : GPUBuffer {
+    let buffer : GPUBuffer | undefined = undefined;
+    if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) {
+      buffer = this.podArgStagingBuffers.shift();
+    }
+    // minimum of 16 bytes
+    let allocSize = 16;
+    if (buffer !== undefined) {
+      allocSize = buffer.size;
+      if (buffer.size < nbytes) {
+        buffer.destroy();
+        buffer = undefined;
+      }
+    }
+    while (allocSize < nbytes) {
+      allocSize *= 2;
+    }
+
+    if (buffer == undefined) {
+      // create uniform buffer
+      buffer = this.device.createBuffer({
+        size: allocSize,
+        usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+      });
+    }
+    assert(nbytes <= buffer.size);
+    return buffer;
   }
 
   /**
@@ -414,7 +458,7 @@ export class WebGPUContext {
    * @param asyncMode Whether use async mode.
    * @returns The shader function or promise of shader func.
    */
-  private createShadeInternl(
+  private createShadeInternal(
     finfo: FunctionInfo,
     code: string,
     asyncMode: boolean
@@ -439,23 +483,41 @@ export class WebGPUContext {
       }
     }
 
-    assert(paramWriteAccess.length == finfo.arg_types.length);
 
     const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
+    const bufferArgIndices : Array<number> = [];
+    const podArgIndices : Array<number> = [];
+
     for (let i = 0; i < finfo.arg_types.length; ++i) {
       const dtype = finfo.arg_types[i];
       if (dtype == "handle") {
         layoutEntries.push({
-          binding: i,
+          binding: bufferArgIndices.length,
           visibility: GPUShaderStage.COMPUTE,
           buffer :  {
-            type: paramWriteAccess[i] ? "storage" : "read-only-storage"
+            type: paramWriteAccess[bufferArgIndices.length] ? "storage" : "read-only-storage"
           }
         });
+        bufferArgIndices.push(i);
+      } else if (dtype.startsWith("int") || dtype.startsWith("uint") || dtype.startsWith("float")) {
+        podArgIndices.push(i);
       } else {
+        throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader");
       }
     }
+
+    assert(paramWriteAccess.length == bufferArgIndices.length);
+    // POD arguments are pass in the end
+    if (podArgIndices.length != 0) {
+      layoutEntries.push({
+        binding: bufferArgIndices.length,
+        visibility: GPUShaderStage.COMPUTE,
+        buffer :  {
+          type: "uniform"
+        }
+      });
+    }
+
     const bindGroupLayout = this.device.createBindGroupLayout({
       entries: layoutEntries
     });
@@ -476,13 +538,47 @@ export class WebGPUContext {
         const compute = commandEncoder.beginComputePass();
         compute.setPipeline(pipeline);
         const bindGroupEntries: Array<GPUBindGroupEntry> = [];
-        assert(args.length == layoutEntries.length + dispatchToDim.length);
+        const numBufferOrPodArgs = bufferArgIndices.length + podArgIndices.length;
+
+        assert(args.length == numBufferOrPodArgs + dispatchToDim.length);
 
-        for (let i = 0; i < layoutEntries.length; ++i) {
+        for (let i = 0; i < bufferArgIndices.length; ++i) {
           bindGroupEntries.push({
             binding: i,
             resource: {
-              buffer: this.gpuBufferFromPtr(args[i])
+              buffer: this.gpuBufferFromPtr(args[bufferArgIndices[i]])
+            }
+          });
+        }
+
+        // push pod buffer
+        if (podArgIndices.length != 0) {
+          const sizeOfI32 = 4;
+          const podArgBuffer = this.getPodArgsBuffer(podArgIndices.length * sizeOfI32);
+          const i32View = new Int32Array(podArgIndices.length);
+          const u32View = new Uint32Array(i32View.buffer);
+          const f32View = new Float32Array(i32View.buffer);
+
+          for (let i = 0; i < podArgIndices.length; ++i) {
+            const value = args[podArgIndices[i]];
+            const dtype = finfo.arg_types[podArgIndices[i]];
+            if (dtype.startsWith("int")) {
+              i32View[i] = value;
+            } else if (dtype.startsWith("uint")) {
+              u32View[i] = value;
+            } else if (dtype.startsWith("float")) {
+              f32View[i] = value;
+            } else {
+              throw Error("Unknown pod dtype " + dtype);
+            }
+          }
+          this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer);
+
+          bindGroupEntries.push({
+            binding: bufferArgIndices.length,
+            resource: {
+              buffer: podArgBuffer,
+              size: i32View.buffer.byteLength
             }
           });
         }
@@ -493,7 +589,7 @@ export class WebGPUContext {
         }));
         const wl: Array<number> = [1, 1, 1, 1, 1, 1];
         for (let i = 0; i < dispatchToDim.length; ++i) {
-          wl[dispatchToDim[i]] = args[layoutEntries.length + i];
+          wl[dispatchToDim[i]] = args[numBufferOrPodArgs + i];
         }
 
         // get around 65535 restriction of blockIdx.x
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py
index 986393e9d4..31dd8fc043 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -38,17 +38,17 @@ def test_rpc():
     target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm")
     runtime = Runtime("cpp", {"system-lib": True})
 
-    n = 2048
+    n = te.var("n")
     A = te.placeholder((n,), name="A")
-    B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B")
-    s = te.create_schedule(B.op)
+    B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i) + 1)), name="B")
+    mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]))
+    sch = tvm.tir.Schedule(mod)
+    (i,) = sch.get_loops(block=sch.get_block("B"))
+    i0, i1 = sch.split(i, [None, 128])
+    sch.bind(i0, "blockIdx.x")
+    sch.bind(i1, "threadIdx.x")
 
-    num_thread = 2
-    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
-    s[B].bind(xi, te.thread_axis("threadIdx.x"))
-    s[B].bind(xo, te.thread_axis("blockIdx.x"))
-
-    fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone")
+    fadd = tvm.build(sch.mod, target=target, runtime=runtime)
     temp = utils.tempdir()
 
     wasm_path = temp.relpath("addone_gpu.wasm")
@@ -62,21 +62,21 @@ def test_rpc():
         session_constructor_args=["rpc.WasmSession", wasm_binary],
     )
 
-    def check(remote):
+    def check(remote, size):
         # basic function checks.
         dev = remote.webgpu(0)
-        adata = np.random.uniform(size=n).astype(A.dtype)
+        adata = np.random.uniform(size=size).astype(A.dtype)
         a = tvm.nd.array(adata, dev)
-        b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+        b = tvm.nd.array(np.zeros(size, dtype=A.dtype), dev)
 
         np.testing.assert_equal(a.numpy(), adata)
         f1 = remote.system_lib()
-        addone = f1.get_function("addone")
+        addone = f1.get_function("main")
         addone(a, b)
         np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5)
         print("Test pass..")
 
-    check(remote)
+    check(remote, 2049)
 
 
 test_rpc()

Reply via email to