This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a6831ba9c4 [Unity] Enable pod args in WebGPU (#14560)
a6831ba9c4 is described below
commit a6831ba9c4cde5a14879cc2b3911e81977beecd4
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 10 14:13:56 2023 -0400
[Unity] Enable pod args in WebGPU (#14560)
This PR adds POD argument support in webgpu.
---
python/tvm/exec/rpc_proxy.py | 2 +-
python/tvm/rpc/proxy.py | 3 +-
src/target/source/codegen_webgpu.cc | 46 +++++++++++--
web/src/runtime.ts | 3 +-
web/src/webgpu.ts | 130 +++++++++++++++++++++++++++++++-----
web/tests/python/webgpu_rpc_test.py | 28 ++++----
6 files changed, 170 insertions(+), 42 deletions(-)
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index 8cf1e4010b..eaa406906a 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -35,7 +35,7 @@ def find_example_resource():
("/", os.path.join(base_path, "web", "dist", "wasm",
"tvmjs_runtime.wasi.js")),
("/", index_page),
]
- allow_format = ("json", "bin", "js", "wasm", "html")
+ allow_format = ("json", "bin", "js", "wasm", "html", "css")
# recursively append things in www, up to two levels
resource_bases = [
diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py
index 59af53d4e1..a6709a3421 100644
--- a/python/tvm/rpc/proxy.py
+++ b/python/tvm/rpc/proxy.py
@@ -204,7 +204,8 @@ class WebSocketHandler(websocket.WebSocketHandler,
ForwardHandler):
MIME_MAP = {
- "js": "application/javascript",
+ "js": "text/javascript",
+ "css": "text/css",
"wasm": "application/wasm",
"json": "application/json",
}
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 32b3206394..188aa12f3a 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -146,8 +146,9 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f, bool skip_re
// setup buffer arguments
for (Var arg : f->params) {
DataType t = arg.dtype();
+ func_info.arg_types.push_back(t);
+
if (t.is_handle()) {
- func_info.arg_types.push_back(t);
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a
type_annotation as a "
"PointerType, "
@@ -184,6 +185,43 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f, bool skip_re
}
}
+ // Store all pod arguments in a single buffer of int32
+ // do bitcast to change to other data types
+ if (pod_args.size() != 0) {
+ std::string type_pod_args = name_supply_->FreshName("PODArgs");
+ std::string val_pod_args = name_supply_->FreshName("podArgs");
+
+ this->decl_stream << "\nstruct " << type_pod_args << " {\n";
+
+ for (size_t i = 0; i < pod_args.size(); ++i) {
+ Var v = pod_args[i];
+ ICHECK(!v.dtype().is_handle());
+ std::string vid = AllocVarID(v.get());
+
+ if (v.dtype() == DataType::Int(32)) {
+ this->decl_stream << " " << vid << ": i32";
+ } else if (v.dtype() == DataType::UInt(32)) {
+ this->decl_stream << " " << vid << ": u32";
+ } else if (v.dtype() == DataType::Float(32)) {
+ this->decl_stream << " " << vid << ": f32";
+ } else {
+ LOG(FATAL) << "Do not support pod argument type " << v.dtype();
+ }
+ if (i + 1 != pod_args.size()) {
+ this->decl_stream << ",\n";
+ } else {
+ this->decl_stream << "\n}\n";
+ }
+ // value ref
+ std::ostringstream vref;
+ vref << val_pod_args << "." << vid;
+ var_idmap_[v.get()] = vref.str();
+ }
+
+ this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
+ << "var<uniform> " << val_pod_args << " : " <<
type_pod_args << ";\n\n";
+ }
+
// setup thread tags and param access in launch param tags;
if (auto opt =
f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
auto thread_axis = opt.value();
@@ -194,12 +232,6 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const
PrimFunc& f, bool skip_re
os_param_access << "]";
func_info.launch_param_tags.push_back(os_param_access.str());
- if (pod_args.size() != 0) {
- // setup POD arguments
- // TODO(tvm-team): store as a uniform, readonly buffer.
- LOG(FATAL) << "Do not support pod arguments for now";
- }
-
ICHECK(!info.has_block_index_z)
<< "blockIdx.z is not supported in WebGPU to accomodate large
blockIdx.x";
// annotate workgroup
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index f3a6029bbe..427bb2d8f8 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -67,6 +67,7 @@ class FFILibrary implements Disposable {
while (this.recycledCallStacks.length != 0) {
(this.recycledCallStacks.pop() as Disposable).dispose();
}
+ this.webGPUContext?.dispose();
}
sizeofPtr(): number {
@@ -1031,8 +1032,6 @@ export class Instance implements Disposable {
}
dispose(): void {
- // dispose canvas resource
- this.lib.webGPUContext?.disposeCanvas();
// order matters
// ctx release goes back into lib.
this.ctx.dispose();
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index ac39595c76..c68c42520f 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -285,7 +285,10 @@ export class WebGPUContext {
// internal data
private bufferTable: Array<GPUBuffer | undefined> = [undefined];
private bufferTableFreeId: Array<number> = [];
+ private podArgStagingBuffers: Array<GPUBuffer> = [];
private canvasRenderManager?: CanvaRenderManager = undefined;
+ // number of pod arg staging buffers
+ private maxNumPodArgsStagingBuffers: number = 2;
// flags for debugging
// stats of the runtime.
// peak allocation
@@ -307,18 +310,25 @@ export class WebGPUContext {
}
/**
- * Wait for all pending GPU tasks to complete
+ * Dispose context.
*/
- async sync(): Promise<void> {
- await this.device.queue.onSubmittedWorkDone();
+ dispose() {
+ this.canvasRenderManager?.dispose();
+ this.bufferTableFreeId = [];
+ while (this.bufferTable.length != 0) {
+ this.bufferTable.pop()?.destroy();
+ }
+ while (this.podArgStagingBuffers.length != 0) {
+ this.podArgStagingBuffers.pop()?.destroy();
+ }
+ this.device.destroy();
}
/**
- * Dispose the binded canvas.
+ * Wait for all pending GPU tasks to complete
*/
- disposeCanvas() {
- this.canvasRenderManager?.dispose();
- this.canvasRenderManager = undefined;
+ async sync(): Promise<void> {
+ await this.device.queue.onSubmittedWorkDone();
}
/**
@@ -391,7 +401,7 @@ export class WebGPUContext {
* @returns The shader
*/
createShader(finfo: FunctionInfo, code: string) : Function {
- return this.createShadeInternl(finfo, code, false) as Function;
+ return this.createShadeInternal(finfo, code, false) as Function;
}
/**
@@ -403,7 +413,41 @@ export class WebGPUContext {
* @returns The shader
*/
async createShaderAsync(finfo: FunctionInfo, code: string) :
Promise<Function> {
- return await (this.createShadeInternl(finfo, code, true) as
Promise<Function>);
+ return await (this.createShadeInternal(finfo, code, true) as
Promise<Function>);
+ }
+
+ /**
+ * Get the pod arg staging buffer
+ * @param nbytes The minimum size.
+ * @returns The allocated buffer
+ */
+ private getPodArgsBuffer(nbytes: number) : GPUBuffer {
+ let buffer : GPUBuffer | undefined = undefined;
+ if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) {
+ buffer = this.podArgStagingBuffers.shift();
+ }
+ // minimum of 16 bytes
+ let allocSize = 16;
+ if (buffer !== undefined) {
+ allocSize = buffer.size;
+ if (buffer.size < nbytes) {
+ buffer.destroy();
+ buffer = undefined;
+ }
+ }
+ while (allocSize < nbytes) {
+ allocSize *= 2;
+ }
+
+ if (buffer == undefined) {
+ // create uniform buffer
+ buffer = this.device.createBuffer({
+ size: allocSize,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+ });
+ }
+ assert(nbytes <= buffer.size);
+ return buffer;
}
/**
@@ -414,7 +458,7 @@ export class WebGPUContext {
* @param asyncMode Whether use async mode.
* @returns The shader function or promise of shader func.
*/
- private createShadeInternl(
+ private createShadeInternal(
finfo: FunctionInfo,
code: string,
asyncMode: boolean
@@ -439,23 +483,41 @@ export class WebGPUContext {
}
}
- assert(paramWriteAccess.length == finfo.arg_types.length);
const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
+ const bufferArgIndices : Array<number> = [];
+ const podArgIndices : Array<number> = [];
+
for (let i = 0; i < finfo.arg_types.length; ++i) {
const dtype = finfo.arg_types[i];
if (dtype == "handle") {
layoutEntries.push({
- binding: i,
+ binding: bufferArgIndices.length,
visibility: GPUShaderStage.COMPUTE,
buffer : {
- type: paramWriteAccess[i] ? "storage" : "read-only-storage"
+ type: paramWriteAccess[bufferArgIndices.length] ? "storage" :
"read-only-storage"
}
});
+ bufferArgIndices.push(i);
+ } else if (dtype.startsWith("int") || dtype.startsWith("uint") ||
dtype.startsWith("float")) {
+ podArgIndices.push(i);
} else {
throw new Error("Cannot handle argument type " + dtype + " in WebGPU
shader");
}
}
+
+ assert(paramWriteAccess.length == bufferArgIndices.length);
+ // POD arguments are passed at the end
+ if (podArgIndices.length != 0) {
+ layoutEntries.push({
+ binding: bufferArgIndices.length,
+ visibility: GPUShaderStage.COMPUTE,
+ buffer : {
+ type: "uniform"
+ }
+ });
+ }
+
const bindGroupLayout = this.device.createBindGroupLayout({
entries: layoutEntries
});
@@ -476,13 +538,47 @@ export class WebGPUContext {
const compute = commandEncoder.beginComputePass();
compute.setPipeline(pipeline);
const bindGroupEntries: Array<GPUBindGroupEntry> = [];
- assert(args.length == layoutEntries.length + dispatchToDim.length);
+ const numBufferOrPodArgs = bufferArgIndices.length +
podArgIndices.length;
+
+ assert(args.length == numBufferOrPodArgs + dispatchToDim.length);
- for (let i = 0; i < layoutEntries.length; ++i) {
+ for (let i = 0; i < bufferArgIndices.length; ++i) {
bindGroupEntries.push({
binding: i,
resource: {
- buffer: this.gpuBufferFromPtr(args[i])
+ buffer: this.gpuBufferFromPtr(args[bufferArgIndices[i]])
+ }
+ });
+ }
+
+ // push pod buffer
+ if (podArgIndices.length != 0) {
+ const sizeOfI32 = 4;
+ const podArgBuffer = this.getPodArgsBuffer(podArgIndices.length *
sizeOfI32);
+ const i32View = new Int32Array(podArgIndices.length);
+ const u32View = new Uint32Array(i32View.buffer);
+ const f32View = new Float32Array(i32View.buffer);
+
+ for (let i = 0; i < podArgIndices.length; ++i) {
+ const value = args[podArgIndices[i]];
+ const dtype = finfo.arg_types[podArgIndices[i]];
+ if (dtype.startsWith("int")) {
+ i32View[i] = value;
+ } else if (dtype.startsWith("uint")) {
+ u32View[i] = value;
+ } else if (dtype.startsWith("float")) {
+ f32View[i] = value;
+ } else {
+ throw Error("Unknown pod dtype " + dtype);
+ }
+ }
+ this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer);
+
+ bindGroupEntries.push({
+ binding: bufferArgIndices.length,
+ resource: {
+ buffer: podArgBuffer,
+ size: i32View.buffer.byteLength
}
});
}
@@ -493,7 +589,7 @@ export class WebGPUContext {
}));
const wl: Array<number> = [1, 1, 1, 1, 1, 1];
for (let i = 0; i < dispatchToDim.length; ++i) {
- wl[dispatchToDim[i]] = args[layoutEntries.length + i];
+ wl[dispatchToDim[i]] = args[numBufferOrPodArgs + i];
}
// get around 65535 restriction of blockIdx.x
diff --git a/web/tests/python/webgpu_rpc_test.py
b/web/tests/python/webgpu_rpc_test.py
index 986393e9d4..31dd8fc043 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -38,17 +38,17 @@ def test_rpc():
target = tvm.target.Target("webgpu", host="llvm
-mtriple=wasm32-unknown-unknown-wasm")
runtime = Runtime("cpp", {"system-lib": True})
- n = 2048
+ n = te.var("n")
A = te.placeholder((n,), name="A")
- B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B")
- s = te.create_schedule(B.op)
+ B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i) + 1)), name="B")
+ mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]))
+ sch = tvm.tir.Schedule(mod)
+ (i,) = sch.get_loops(block=sch.get_block("B"))
+ i0, i1 = sch.split(i, [None, 128])
+ sch.bind(i0, "blockIdx.x")
+ sch.bind(i1, "threadIdx.x")
- num_thread = 2
- xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
- s[B].bind(xi, te.thread_axis("threadIdx.x"))
- s[B].bind(xo, te.thread_axis("blockIdx.x"))
-
- fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone")
+ fadd = tvm.build(sch.mod, target=target, runtime=runtime)
temp = utils.tempdir()
wasm_path = temp.relpath("addone_gpu.wasm")
@@ -62,21 +62,21 @@ def test_rpc():
session_constructor_args=["rpc.WasmSession", wasm_binary],
)
- def check(remote):
+ def check(remote, size):
# basic function checks.
dev = remote.webgpu(0)
- adata = np.random.uniform(size=n).astype(A.dtype)
+ adata = np.random.uniform(size=size).astype(A.dtype)
a = tvm.nd.array(adata, dev)
- b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+ b = tvm.nd.array(np.zeros(size, dtype=A.dtype), dev)
np.testing.assert_equal(a.numpy(), adata)
f1 = remote.system_lib()
- addone = f1.get_function("addone")
+ addone = f1.get_function("main")
addone(a, b)
np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1),
atol=1e-5, rtol=1e-5)
print("Test pass..")
- check(remote)
+ check(remote, 2049)
test_rpc()