This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e08e702fa [WebGPU] Implement `tir.dp4a` with WGSL built-in function
`dot4I8Packed` (#16976)
3e08e702fa is described below
commit 3e08e702fa27b51a948792d467a7734cd6995cf4
Author: Jiawei Shao <[email protected]>
AuthorDate: Fri Jul 5 02:03:56 2024 +0800
[WebGPU] Implement `tir.dp4a` with WGSL built-in function `dot4I8Packed`
(#16976)
* [WebGPU] Support `__dp4a(int8x4, int8x4)` as a pure extern method
This patch adds support for `__dp4a(int8x4, int8x4)` as a pure
extern method of the WebGPU target. In the generated WGSL shader,
`int8x4` is translated into `u32`, and `__dp4a(int8x4, int8x4)`
is translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.
Here is an example of using `__dp4a` with the WebGPU target:
```
import tvm
from tvm import te

# NOTE(review): the original snippet passed `tgt` to tvm.build without
# defining it; a plain WebGPU target (host left at TVM's default) is assumed
# here -- confirm against the intended build configuration.
tgt = tvm.target.Target("webgpu")

# Element-wise `__dp4a` over two int8x4 buffers of symbolic length n.
# In the generated WGSL each int8x4 element becomes a `u32`, and the
# extern call becomes `dot4I8Packed(u32, u32)`.
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(
    A.shape,
    lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]),
    name="C",
)

# Split the single axis by 64 and bind the outer/inner loops to the
# block/thread indices so the kernel can be lowered for a GPU target.
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
Issue: #16627
* Add validation
* Add `dot4I8Packed` to WebGPU lower intrinsic
* Implement builtin `dp4a` with `dot4I8Packed`
* Small fix
* Add missing comment
---
src/target/source/codegen_webgpu.cc | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index a95f6e0fa0..b76b05470d 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -410,6 +410,14 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLIN
this->EndScope(else_scope);
}
os << result;
+ } else if (op->op.same_as(builtin::dp4a())) {
+ // generate `dot4I8Packed(vec1, vec2) + acc` for the builtin `dp4a`
+ os << "dot4I8Packed(";
+ this->PrintExpr(op->args[0], os);
+ os << ", ";
+ this->PrintExpr(op->args[1], os);
+ os << ") + ";
+ this->PrintExpr(op->args[2], os);
} else {
CodeGenC::VisitExpr_(op, os);
}