This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f27aa8db7 [WebGPU][CodeGen] Override PrintVecElemLoad and Store for
WebGPU (#17917)
3f27aa8db7 is described below
commit 3f27aa8db7daf3b1e286c88b616274b6703319b1
Author: Charlie Ruan <[email protected]>
AuthorDate: Sat May 3 22:45:24 2025 -0400
[WebGPU][CodeGen] Override PrintVecElemLoad and Store for WebGPU (#17917)
This PR overrides `PrintVecElemLoad()` and `PrintVecElemStore()`
for the WebGPU backend.
Otherwise, we would generate expressions like `(QK_local[0i].s0)` for
WebGPU, which is not valid WGSL syntax (WGSL has no `.sN` vector
component accessor). Instead, we now generate `(QK_local[0i][0])`.
`QK_local` here is an `array<vec4<f32>, 1>`.
This issue prevented WebLLM from generating the correct kernel
after https://github.com/apache/tvm/pull/17748
Co-authored-by: Ruihang Lai <[email protected]>
---
src/target/source/codegen_webgpu.cc | 11 +++++++++++
src/target/source/codegen_webgpu.h | 4 ++++
2 files changed, 15 insertions(+)
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 1d1df91dc4..90be766638 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -348,6 +348,17 @@ void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string&
stream << " = " << src << ";\n";
}
+void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DataType t, int i,
+ std::ostream& os) { // NOLINT(*)
+ os << vec << "[" << i << "]";  // read lane i as vec[i]: WGSL rejects the `.sN` accessor; `t` is unused
+}
+
+void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DataType t, int i,
+ const std::string& value) {
+ this->PrintIndent();
+ stream << vec << "[" << i << "] = " << value << ";\n";  // write lane i as vec[i] = value; WGSL rejects `.sN`
+}
+
void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
// NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 09f99fb886..b8f2f9a79d 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -58,6 +58,10 @@ class CodeGenWebGPU final : public CodeGenC {
// assignment printing
void PrintSSAAssign(const std::string& target, const std::string& src,
DataType type) final;
+ // overload printing vector element load/store (WGSL indexes vector lanes as vec[i], not .sN)
+ void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final;
+ void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
+
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; //
NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; //
NOLINT(*)