This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b04b1acf40 [Web] Compatibility with PagedKVCache in WebGPU (#16554)
b04b1acf40 is described below
commit b04b1acf409131a3880a3a4c7824d70e83b13456
Author: Charlie Ruan <[email protected]>
AuthorDate: Mon Feb 12 13:52:07 2024 -0500
[Web] Compatibility with PagedKVCache in WebGPU (#16554)
This PR introduces various WebGPU changes to accommodate the new
`PagedKVCache` interface. All changes below are essential for making models
that use PagedKVCache runnable under WebGPU:
- Require exactly same-dtype matching for WebGPU smem reuse in
`storage_rewrite.cc`
- Rename `AttentionKVCache` to `AttentionKVCacheLegacy` for the old KV cache
interface in `lm_support.cc`, and include `paged_kv_cache.cc` when building
`wasm_runtime`
- In WebGPU codegen:
  - Declare local variables within the function scope rather than the
    module scope
  - Generate `while (true)` rather than `while (1)`
- Require 10 `maxStorageBuffersPerShaderStage` rather than the default 8
from the WebGPU device when initializing runtime; this is required for new
kernels introduced in PagedKVCache
- In `deviceCopyToGPU()`, when the number of raw bytes to write is not a
multiple of 4, we pad them, as required by WebGPU's `writeBuffer()`.
---------
Co-authored-by: Rick Zhou <[email protected]>
---
src/runtime/relax_vm/lm_support.cc | 40 +++++++++++-----------
src/target/source/codegen_webgpu.cc | 29 ++++++++++++----
src/target/source/codegen_webgpu.h | 1 +
src/tir/transforms/storage_rewrite.cc | 5 +--
web/emcc/wasm_runtime.cc | 1 +
web/src/webgpu.ts | 63 ++++++++++++++++++++++-------------
6 files changed, 88 insertions(+), 51 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index ecaacb7770..fccff2cecd 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -226,18 +226,19 @@ class AttentionKVCacheObj : public Object {
}
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
- static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
+ static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
};
/*! \brief reference to closure. */
-class AttentionKVCache : public ObjectRef {
+class AttentionKVCacheLegacy : public ObjectRef {
public:
/*!
* \brief Create the attention kv cache.
* \param init_data The initial reserved.
*/
- static AttentionKVCache Create(NDArray init_data, ShapeTuple reserve_shape,
int init_fill_count) {
+ static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple
reserve_shape,
+ int init_fill_count) {
auto n = make_object<AttentionKVCacheObj>();
n->data = NDArray::Empty(reserve_shape, init_data->dtype,
init_data->device);
n->fill_count = 0;
@@ -246,10 +247,10 @@ class AttentionKVCache : public ObjectRef {
n->fill_count = init_fill_count;
n->window_attention_current_pos = init_fill_count; // window attention
only
}
- return AttentionKVCache(n);
+ return AttentionKVCacheLegacy(n);
}
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, ObjectRef,
AttentionKVCacheObj);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
AttentionKVCacheObj);
};
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
@@ -258,24 +259,24 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
// Register runtime functions
//-------------------------------------------------
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
- .set_body_typed(AttentionKVCache::Create);
+ .set_body_typed(AttentionKVCacheLegacy::Create);
-AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value)
{
+AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache,
NDArray value) {
cache->Update(value);
return cache;
}
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update").set_body_typed(AttentionKVCacheUpdate);
-AttentionKVCache AttentionKVCacheAppend(AttentionKVCache cache, NDArray value)
{
+AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache,
NDArray value) {
cache->Append(value);
return cache;
}
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append").set_body_typed(AttentionKVCacheAppend);
-AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache,
NDArray value,
- int64_t max_cache_size) {
+AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy
cache, NDArray value,
+ int64_t max_cache_size) {
cache->WindowOverride(value, max_cache_size);
return cache;
}
@@ -283,9 +284,10 @@ AttentionKVCache
AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
.set_body_typed(AttentionKVCacheWindowOverride);
-AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache
cache, NDArray value,
- int64_t
max_cache_size,
- int64_t
num_attention_sinks) {
+AttentionKVCacheLegacy
AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache,
+ NDArray value,
+ int64_t
max_cache_size,
+ int64_t
num_attention_sinks) {
cache->WindowOverride(value, max_cache_size, num_attention_sinks);
return cache;
}
@@ -293,7 +295,7 @@ AttentionKVCache
AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache,
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
.set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
-NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
+NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ShapeTuple shape) {
return cache->View(shape);
}
@@ -302,7 +304,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view")
CHECK(args.size() == 1 || args.size() == 2)
<< "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2
arguments, but got "
<< args.size() << ".";
- AttentionKVCache cache = args[0];
+ AttentionKVCacheLegacy cache = args[0];
if (args.size() == 2) {
ShapeTuple shape = args[1];
*rv = cache->View(shape);
@@ -316,8 +318,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view")
}
});
-void AttentionKVCacheArrayPopN(Array<AttentionKVCache> caches, int64_t n) {
- for (AttentionKVCache cache : caches) {
+void AttentionKVCacheArrayPopN(Array<AttentionKVCacheLegacy> caches, int64_t
n) {
+ for (AttentionKVCacheLegacy cache : caches) {
cache->PopN(static_cast<size_t>(n));
}
}
@@ -325,8 +327,8 @@ void AttentionKVCacheArrayPopN(Array<AttentionKVCache>
caches, int64_t n) {
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn")
.set_body_typed(AttentionKVCacheArrayPopN);
-void AttentionKVCacheArrayClear(Array<AttentionKVCache> caches) {
- for (AttentionKVCache cache : caches) {
+void AttentionKVCacheArrayClear(Array<AttentionKVCacheLegacy> caches) {
+ for (AttentionKVCacheLegacy cache : caches) {
cache->Clear();
}
}
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 1702699ac2..5ede16d2f4 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -599,13 +599,15 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
PrintType(op->dtype, this->decl_stream);
this->decl_stream << ", " << constant_size << ">;\n";
} else if (storage_scope.rank == runtime::StorageRank::kLocal) {
- this->decl_stream << "var<private> " << vid << " : array<";
- PrintType(op->dtype, this->decl_stream);
- this->decl_stream << ", " << constant_size << ">;\n";
- // this->PrintIndent();
- // this->stream << "var " << vid << " : array<";
- // PrintType(op->dtype, this->stream);
- // this->stream << ", " << constant_size << ">;\n";
+ // TODO(Charlie): These code would cause non-uniformity as it introduces
variables in module
+ // scope rather than function scope; but it was included for some unknown
reasons; kept for now.
+ // this->decl_stream << "var<private> " << vid << " : array<";
+ // PrintType(op->dtype, this->decl_stream);
+ // this->decl_stream << ", " << constant_size << ">;\n";
+ this->PrintIndent();
+ this->stream << "var " << vid << " : array<";
+ PrintType(op->dtype, this->stream);
+ this->stream << ", " << constant_size << ">;\n";
} else {
LOG(FATAL) << "WebGPU: Do not support storage scope: " <<
storage_scope.to_string();
}
@@ -636,6 +638,19 @@ void CodeGenWebGPU::VisitStmt_(const AllocateConstNode*
op) {
LOG(FATAL) << "WebGPU: do not support alloc const";
}
+void CodeGenWebGPU::VisitStmt_(const WhileNode* op) {  // emits: while (true) { if (!(cond)) { break; } body }
+  PrintIndent();
+  stream << "while (true) {\n";  // literal `true` is emitted instead of `1` as the loop condition
+  int while_scope = BeginScope();
+  std::string cond = PrintExpr(op->condition);  // printed after BeginScope(), i.e. within the loop scope
+  PrintIndent();
+  stream << "if (!(" << cond << ")) { break; }\n";  // leave the loop once the While condition is false
+  PrintStmt(op->body);
+  this->EndScope(while_scope);
+  PrintIndent();
+  stream << "}\n";
+}
+
//-------------------------------------------------
// WebGPUSourceModule to enable export
//-------------------------------------------------
diff --git a/src/target/source/codegen_webgpu.h
b/src/target/source/codegen_webgpu.h
index f12cd3430d..cf642b9e07 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -73,6 +73,7 @@ class CodeGenWebGPU final : public CodeGenC {
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const AllocateConstNode* op) final;
+ void VisitStmt_(const WhileNode* op) final;
private:
/*!
diff --git a/src/tir/transforms/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
index 991c48219b..dd27397f36 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1724,8 +1724,9 @@ Pass StorageRewrite() {
}
Optional<Target> target = f->GetAttr<Target>("target");
- if (target.defined() && target.value()->kind->name == "vulkan") {
- // Require exactly same-dtype matching in smem reuse for Vulkan
+ if (target.defined() &&
+ (target.value()->kind->name == "vulkan" || target.value()->kind->name
== "webgpu")) {
+ // Require exactly same-dtype matching in smem reuse for Vulkan and
WebGPU
reuse_require_exact_matched_dtype = true;
}
auto* n = f.CopyOnWrite();
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 311bbd9971..be9704eaef 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -60,6 +60,7 @@
#include "src/runtime/relax_vm/executable.cc"
#include "src/runtime/relax_vm/lm_support.cc"
#include "src/runtime/relax_vm/ndarray_cache_support.cc"
+#include "src/runtime/relax_vm/paged_kv_cache.cc"
#include "src/runtime/relax_vm/vm.cc"
// --- Implementations of backend and wasm runtime API. ---
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 95dc7af9fc..55c53bb8d5 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -35,12 +35,12 @@ export interface GPUDeviceDetectOutput {
*/
export async function detectGPUDevice(): Promise<GPUDeviceDetectOutput |
undefined> {
if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
- const adapter = await
navigator.gpu.requestAdapter({"powerPreference":"high-performance"});
+ const adapter = await navigator.gpu.requestAdapter({ "powerPreference":
"high-performance" });
if (adapter == null) {
throw Error("Cannot find adapter that matches the request");
}
const computeMB = (value: number) => {
- return Math.ceil(value / (1 << 20)) + "MB";
+ return Math.ceil(value / (1 << 20)) + "MB";
}
// more detailed error message
@@ -77,7 +77,7 @@ export async function detectGPUDevice():
Promise<GPUDeviceDetectOutput | undefin
}
const requiredMaxComputeWorkgroupStorageSize = 32 << 10;
- if (requiredMaxComputeWorkgroupStorageSize>
adapter.limits.maxComputeWorkgroupStorageSize) {
+ if (requiredMaxComputeWorkgroupStorageSize >
adapter.limits.maxComputeWorkgroupStorageSize) {
throw Error(
`Cannot initialize runtime because of requested
maxComputeWorkgroupStorageSize ` +
`exceeds limit. requested=${requiredMaxComputeWorkgroupStorageSize}, `
+
@@ -85,7 +85,16 @@ export async function detectGPUDevice():
Promise<GPUDeviceDetectOutput | undefin
);
}
- const requiredFeatures : GPUFeatureName[] = [];
+ const requiredMaxStorageBuffersPerShaderStage = 10; // default is 8
+ if (requiredMaxStorageBuffersPerShaderStage >
adapter.limits.maxStorageBuffersPerShaderStage) {
+ throw Error(
+ `Cannot initialize runtime because of requested
maxStorageBuffersPerShaderStage ` +
+ `exceeds limit. requested=${requiredMaxStorageBuffersPerShaderStage},
` +
+ `limit=${adapter.limits.maxStorageBuffersPerShaderStage}. `
+ );
+ }
+
+ const requiredFeatures: GPUFeatureName[] = [];
// Always require f16 if available
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
@@ -97,6 +106,7 @@ export async function detectGPUDevice():
Promise<GPUDeviceDetectOutput | undefin
maxBufferSize: requiredMaxBufferSize,
maxStorageBufferBindingSize: requiredMaxStorageBufferBindingSize,
maxComputeWorkgroupStorageSize: requiredMaxComputeWorkgroupStorageSize,
+ maxStorageBuffersPerShaderStage:
requiredMaxStorageBuffersPerShaderStage,
},
requiredFeatures
});
@@ -110,7 +120,7 @@ export async function detectGPUDevice():
Promise<GPUDeviceDetectOutput | undefin
}
}
-const canvasRenderWGSL =`
+const canvasRenderWGSL = `
@group(0) @binding(0) var my_sampler : sampler;
@group(0) @binding(1) var my_texture : texture_2d<f32>;
@@ -193,7 +203,7 @@ class CanvasRenderManager implements Disposable {
}),
entryPoint: "fragment_main",
targets: [{
- format: this.canvasTextureFormat,
+ format: this.canvasTextureFormat,
}],
},
primitive: {
@@ -215,7 +225,7 @@ class CanvasRenderManager implements Disposable {
}),
entryPoint: "fragment_clear",
targets: [{
- format: this.canvasTextureFormat,
+ format: this.canvasTextureFormat,
}],
},
primitive: {
@@ -285,7 +295,7 @@ class CanvasRenderManager implements Disposable {
bytesPerRow: this.stagingTexture.width * 4
}, {
texture: this.stagingTexture
- },{
+ }, {
width: this.stagingTexture.width,
height: this.stagingTexture.height
});
@@ -314,7 +324,7 @@ class CanvasRenderManager implements Disposable {
this.device.queue.submit([commandEncoder.finish()]);
}
- dispose() : void {
+ dispose(): void {
this.stagingTexture.destroy();
}
}
@@ -453,7 +463,7 @@ export class WebGPUContext {
* @param code The shader data(in WGSL)
* @returns The shader
*/
- createShader(finfo: FunctionInfo, code: string) : Function {
+ createShader(finfo: FunctionInfo, code: string): Function {
return this.createShadeInternal(finfo, code, false) as Function;
}
@@ -465,7 +475,7 @@ export class WebGPUContext {
* @param code The shader data(in WGSL)
* @returns The shader
*/
- async createShaderAsync(finfo: FunctionInfo, code: string) :
Promise<Function> {
+ async createShaderAsync(finfo: FunctionInfo, code: string):
Promise<Function> {
return await (this.createShadeInternal(finfo, code, true) as
Promise<Function>);
}
@@ -474,8 +484,8 @@ export class WebGPUContext {
* \param nbytes The minimum size.
* \return The allocated buffer
*/
- private getPodArgsBuffer(nbytes: number) : GPUBuffer {
- let buffer : GPUBuffer | undefined = undefined;
+ private getPodArgsBuffer(nbytes: number): GPUBuffer {
+ let buffer: GPUBuffer | undefined = undefined;
if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) {
buffer = this.podArgStagingBuffers.shift();
}
@@ -538,8 +548,8 @@ export class WebGPUContext {
const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
- const bufferArgIndices : Array<number> = [];
- const podArgIndices : Array<number> = [];
+ const bufferArgIndices: Array<number> = [];
+ const podArgIndices: Array<number> = [];
for (let i = 0; i < finfo.arg_types.length; ++i) {
const dtype = finfo.arg_types[i];
@@ -547,7 +557,7 @@ export class WebGPUContext {
layoutEntries.push({
binding: bufferArgIndices.length,
visibility: GPUShaderStage.COMPUTE,
- buffer : {
+ buffer: {
type: paramWriteAccess[bufferArgIndices.length] ? "storage" :
"read-only-storage"
}
});
@@ -564,7 +574,7 @@ export class WebGPUContext {
layoutEntries.push({
binding: bufferArgIndices.length,
visibility: GPUShaderStage.COMPUTE,
- buffer : {
+ buffer: {
type: "uniform"
}
});
@@ -573,14 +583,14 @@ export class WebGPUContext {
entries: layoutEntries
});
const pipelineLayout = this.device.createPipelineLayout({
- bindGroupLayouts: [ bindGroupLayout ]
+ bindGroupLayouts: [bindGroupLayout]
});
// Function to create the pipeline.
- const createShaderFunc = (pipeline: GPUComputePipeline): Function => {
+ const createShaderFunc = (pipeline: GPUComputePipeline): Function => {
const submitShader = (...args: Array<GPUPointer | number>): void => {
if (this.debugShaderSubmitLimit != -1 &&
- this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
+ this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
this.shaderSubmitCounter += 1;
return;
}
@@ -675,8 +685,8 @@ export class WebGPUContext {
if (this.debugLogFinish) {
const currCounter = this.shaderSubmitCounter;
- this.device.queue.onSubmittedWorkDone().then(()=> {
- console.log("["+ currCounter + "][Debug] finish shader" +
finfo.name);
+ this.device.queue.onSubmittedWorkDone().then(() => {
+ console.log("[" + currCounter + "][Debug] finish shader" +
finfo.name);
});
}
this.shaderSubmitCounter += 1;
@@ -799,7 +809,14 @@ export class WebGPUContext {
nbytes: number
): void {
// Perhaps it would be more useful to use a staging buffer?
- const rawBytes = this.memory.loadRawBytes(from, nbytes);
+    let rawBytes = this.memory.loadRawBytes(from, nbytes);
+    if (rawBytes.length % 4 !== 0) {  // writeBuffer needs a size that is a multiple of 4
+      const toPad = 4 - rawBytes.length % 4;
+      const paddedBytes = new Uint8Array(rawBytes.length + toPad);
+      paddedBytes.set(rawBytes);  // copy the payload first; the padding tail stays zero
+      rawBytes = paddedBytes;
+      nbytes = nbytes + toPad;
+    }
this.device.queue.writeBuffer(
this.gpuBufferFromPtr(to),
toOffset,