This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b04b1acf40 [Web] Compatibility with PagedKVCache in WebGPU (#16554)
b04b1acf40 is described below

commit b04b1acf409131a3880a3a4c7824d70e83b13456
Author: Charlie Ruan <[email protected]>
AuthorDate: Mon Feb 12 13:52:07 2024 -0500

    [Web] Compatibility with PagedKVCache in WebGPU (#16554)
    
    This PR introduces various WebGPU changes to accommodate the new 
`PagedKVCache` interface. All changes below are essential for making models 
that use PagedKVCache runnable under WebGPU:
    
    - Require exactly same-dtype matching for WebGPU smem reuse in 
`storage_rewrite.cc`
    - Rename `AttentionKVCache` to `AttentionKVCacheLegacy` for the old KV cache 
interface in `lm_support.cc`, and include `paged_kv_cache.cc` in the 
`wasm_runtime` build
    - In WebGPU codegen:
      - Declare local variables within the function scope rather than the 
module scope
      - Generate `while (true)` rather than `while (1)`
    - Request `maxStorageBuffersPerShaderStage` = 10 (instead of the default 8) 
from the WebGPU device when initializing the runtime; the new kernels 
introduced with PagedKVCache require it
    - In `deviceCopyToGPU()`, when the number of bytes to write is not a multiple 
of 4, we pad it, as required by WebGPU's `writeBuffer()`.
    
    ---------
    
    Co-authored-by: Rick Zhou <[email protected]>
---
 src/runtime/relax_vm/lm_support.cc    | 40 +++++++++++-----------
 src/target/source/codegen_webgpu.cc   | 29 ++++++++++++----
 src/target/source/codegen_webgpu.h    |  1 +
 src/tir/transforms/storage_rewrite.cc |  5 +--
 web/emcc/wasm_runtime.cc              |  1 +
 web/src/webgpu.ts                     | 63 ++++++++++++++++++++++-------------
 6 files changed, 88 insertions(+), 51 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc 
b/src/runtime/relax_vm/lm_support.cc
index ecaacb7770..fccff2cecd 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -226,18 +226,19 @@ class AttentionKVCacheObj : public Object {
   }
 
   static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
-  static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
+  static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
   TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
 };
 
 /*! \brief reference to closure. */
-class AttentionKVCache : public ObjectRef {
+class AttentionKVCacheLegacy : public ObjectRef {
  public:
   /*!
    * \brief Create the attention kv cache.
    * \param init_data The initial reserved.
    */
-  static AttentionKVCache Create(NDArray init_data, ShapeTuple reserve_shape, 
int init_fill_count) {
+  static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple 
reserve_shape,
+                                       int init_fill_count) {
     auto n = make_object<AttentionKVCacheObj>();
     n->data = NDArray::Empty(reserve_shape, init_data->dtype, 
init_data->device);
     n->fill_count = 0;
@@ -246,10 +247,10 @@ class AttentionKVCache : public ObjectRef {
       n->fill_count = init_fill_count;
       n->window_attention_current_pos = init_fill_count;  // window attention 
only
     }
-    return AttentionKVCache(n);
+    return AttentionKVCacheLegacy(n);
   }
 
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, ObjectRef, 
AttentionKVCacheObj);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, 
AttentionKVCacheObj);
 };
 
 TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
@@ -258,24 +259,24 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
 //  Register runtime functions
 //-------------------------------------------------
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
-    .set_body_typed(AttentionKVCache::Create);
+    .set_body_typed(AttentionKVCacheLegacy::Create);
 
-AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value) 
{
+AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, 
NDArray value) {
   cache->Update(value);
   return cache;
 }
 
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update").set_body_typed(AttentionKVCacheUpdate);
 
-AttentionKVCache AttentionKVCacheAppend(AttentionKVCache cache, NDArray value) 
{
+AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, 
NDArray value) {
   cache->Append(value);
   return cache;
 }
 
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append").set_body_typed(AttentionKVCacheAppend);
 
-AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, 
NDArray value,
-                                                int64_t max_cache_size) {
+AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy 
cache, NDArray value,
+                                                      int64_t max_cache_size) {
   cache->WindowOverride(value, max_cache_size);
   return cache;
 }
@@ -283,9 +284,10 @@ AttentionKVCache 
AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
     .set_body_typed(AttentionKVCacheWindowOverride);
 
-AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache 
cache, NDArray value,
-                                                         int64_t 
max_cache_size,
-                                                         int64_t 
num_attention_sinks) {
+AttentionKVCacheLegacy 
AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache,
+                                                               NDArray value,
+                                                               int64_t 
max_cache_size,
+                                                               int64_t 
num_attention_sinks) {
   cache->WindowOverride(value, max_cache_size, num_attention_sinks);
   return cache;
 }
@@ -293,7 +295,7 @@ AttentionKVCache 
AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache,
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
     .set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
 
-NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
+NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ShapeTuple shape) {
   return cache->View(shape);
 }
 
@@ -302,7 +304,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view")
       CHECK(args.size() == 1 || args.size() == 2)
           << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 
arguments, but got "
           << args.size() << ".";
-      AttentionKVCache cache = args[0];
+      AttentionKVCacheLegacy cache = args[0];
       if (args.size() == 2) {
         ShapeTuple shape = args[1];
         *rv = cache->View(shape);
@@ -316,8 +318,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view")
       }
     });
 
-void AttentionKVCacheArrayPopN(Array<AttentionKVCache> caches, int64_t n) {
-  for (AttentionKVCache cache : caches) {
+void AttentionKVCacheArrayPopN(Array<AttentionKVCacheLegacy> caches, int64_t 
n) {
+  for (AttentionKVCacheLegacy cache : caches) {
     cache->PopN(static_cast<size_t>(n));
   }
 }
@@ -325,8 +327,8 @@ void AttentionKVCacheArrayPopN(Array<AttentionKVCache> 
caches, int64_t n) {
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn")
     .set_body_typed(AttentionKVCacheArrayPopN);
 
-void AttentionKVCacheArrayClear(Array<AttentionKVCache> caches) {
-  for (AttentionKVCache cache : caches) {
+void AttentionKVCacheArrayClear(Array<AttentionKVCacheLegacy> caches) {
+  for (AttentionKVCacheLegacy cache : caches) {
     cache->Clear();
   }
 }
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index 1702699ac2..5ede16d2f4 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -599,13 +599,15 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
     PrintType(op->dtype, this->decl_stream);
     this->decl_stream << ", " << constant_size << ">;\n";
   } else if (storage_scope.rank == runtime::StorageRank::kLocal) {
-    this->decl_stream << "var<private> " << vid << " : array<";
-    PrintType(op->dtype, this->decl_stream);
-    this->decl_stream << ", " << constant_size << ">;\n";
-    // this->PrintIndent();
-    // this->stream << "var " << vid << " : array<";
-    // PrintType(op->dtype, this->stream);
-    // this->stream << ", " << constant_size << ">;\n";
+    // TODO(Charlie): These code would cause non-uniformity as it introduces 
variables in module
+    // scope rather than function scope; but it was included for some unknown 
reasons; kept for now.
+    // this->decl_stream << "var<private> " << vid << " : array<";
+    // PrintType(op->dtype, this->decl_stream);
+    // this->decl_stream << ", " << constant_size << ">;\n";
+    this->PrintIndent();
+    this->stream << "var " << vid << " : array<";
+    PrintType(op->dtype, this->stream);
+    this->stream << ", " << constant_size << ">;\n";
   } else {
     LOG(FATAL) << "WebGPU: Do not support storage scope: " << 
storage_scope.to_string();
   }
@@ -636,6 +638,19 @@ void CodeGenWebGPU::VisitStmt_(const AllocateConstNode* 
op) {
   LOG(FATAL) << "WebGPU: do not support alloc const";
 }
 
+void CodeGenWebGPU::VisitStmt_(const WhileNode* op) {
+  PrintIndent();
+  stream << "while (true) {\n";
+  int while_scope = BeginScope();
+  std::string cond = PrintExpr(op->condition);
+  PrintIndent();
+  stream << "if (!(" << cond << ")) { break; }\n";
+  PrintStmt(op->body);
+  this->EndScope(while_scope);
+  PrintIndent();
+  stream << "}\n";
+}
+
 //-------------------------------------------------
 // WebGPUSourceModule to enable export
 //-------------------------------------------------
diff --git a/src/target/source/codegen_webgpu.h 
b/src/target/source/codegen_webgpu.h
index f12cd3430d..cf642b9e07 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -73,6 +73,7 @@ class CodeGenWebGPU final : public CodeGenC {
   void VisitStmt_(const AllocateNode* op) final;
   void VisitStmt_(const AssertStmtNode* op) final;
   void VisitStmt_(const AllocateConstNode* op) final;
+  void VisitStmt_(const WhileNode* op) final;
 
  private:
   /*!
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index 991c48219b..dd27397f36 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1724,8 +1724,9 @@ Pass StorageRewrite() {
     }
 
     Optional<Target> target = f->GetAttr<Target>("target");
-    if (target.defined() && target.value()->kind->name == "vulkan") {
-      // Require exactly same-dtype matching in smem reuse for Vulkan
+    if (target.defined() &&
+        (target.value()->kind->name == "vulkan" || target.value()->kind->name 
== "webgpu")) {
+      // Require exactly same-dtype matching in smem reuse for Vulkan and 
WebGPU
       reuse_require_exact_matched_dtype = true;
     }
     auto* n = f.CopyOnWrite();
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 311bbd9971..be9704eaef 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -60,6 +60,7 @@
 #include "src/runtime/relax_vm/executable.cc"
 #include "src/runtime/relax_vm/lm_support.cc"
 #include "src/runtime/relax_vm/ndarray_cache_support.cc"
+#include "src/runtime/relax_vm/paged_kv_cache.cc"
 #include "src/runtime/relax_vm/vm.cc"
 
 // --- Implementations of backend and wasm runtime API. ---
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 95dc7af9fc..55c53bb8d5 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -35,12 +35,12 @@ export interface GPUDeviceDetectOutput {
  */
 export async function detectGPUDevice(): Promise<GPUDeviceDetectOutput | 
undefined> {
   if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
-    const adapter = await 
navigator.gpu.requestAdapter({"powerPreference":"high-performance"});
+    const adapter = await navigator.gpu.requestAdapter({ "powerPreference": 
"high-performance" });
     if (adapter == null) {
       throw Error("Cannot find adapter that matches the request");
     }
     const computeMB = (value: number) => {
-      return Math.ceil(value  / (1 << 20)) + "MB";
+      return Math.ceil(value / (1 << 20)) + "MB";
     }
 
     // more detailed error message
@@ -77,7 +77,7 @@ export async function detectGPUDevice(): 
Promise<GPUDeviceDetectOutput | undefin
     }
 
     const requiredMaxComputeWorkgroupStorageSize = 32 << 10;
-    if (requiredMaxComputeWorkgroupStorageSize> 
adapter.limits.maxComputeWorkgroupStorageSize) {
+    if (requiredMaxComputeWorkgroupStorageSize > 
adapter.limits.maxComputeWorkgroupStorageSize) {
       throw Error(
         `Cannot initialize runtime because of requested 
maxComputeWorkgroupStorageSize ` +
         `exceeds limit. requested=${requiredMaxComputeWorkgroupStorageSize}, ` 
+
@@ -85,7 +85,16 @@ export async function detectGPUDevice(): 
Promise<GPUDeviceDetectOutput | undefin
       );
     }
 
-    const requiredFeatures : GPUFeatureName[] = [];
+    const requiredMaxStorageBuffersPerShaderStage = 10;  // default is 8
+    if (requiredMaxStorageBuffersPerShaderStage > 
adapter.limits.maxStorageBuffersPerShaderStage) {
+      throw Error(
+        `Cannot initialize runtime because of requested 
maxStorageBuffersPerShaderStage ` +
+        `exceeds limit. requested=${requiredMaxStorageBuffersPerShaderStage}, 
` +
+        `limit=${adapter.limits.maxStorageBuffersPerShaderStage}. `
+      );
+    }
+
+    const requiredFeatures: GPUFeatureName[] = [];
     // Always require f16 if available
     if (adapter.features.has("shader-f16")) {
       requiredFeatures.push("shader-f16");
@@ -97,6 +106,7 @@ export async function detectGPUDevice(): 
Promise<GPUDeviceDetectOutput | undefin
         maxBufferSize: requiredMaxBufferSize,
         maxStorageBufferBindingSize: requiredMaxStorageBufferBindingSize,
         maxComputeWorkgroupStorageSize: requiredMaxComputeWorkgroupStorageSize,
+        maxStorageBuffersPerShaderStage: 
requiredMaxStorageBuffersPerShaderStage,
       },
       requiredFeatures
     });
@@ -110,7 +120,7 @@ export async function detectGPUDevice(): 
Promise<GPUDeviceDetectOutput | undefin
   }
 }
 
-const canvasRenderWGSL =`
+const canvasRenderWGSL = `
 @group(0) @binding(0) var my_sampler : sampler;
 @group(0) @binding(1) var my_texture : texture_2d<f32>;
 
@@ -193,7 +203,7 @@ class CanvasRenderManager implements Disposable {
         }),
         entryPoint: "fragment_main",
         targets: [{
-            format: this.canvasTextureFormat,
+          format: this.canvasTextureFormat,
         }],
       },
       primitive: {
@@ -215,7 +225,7 @@ class CanvasRenderManager implements Disposable {
         }),
         entryPoint: "fragment_clear",
         targets: [{
-            format: this.canvasTextureFormat,
+          format: this.canvasTextureFormat,
         }],
       },
       primitive: {
@@ -285,7 +295,7 @@ class CanvasRenderManager implements Disposable {
       bytesPerRow: this.stagingTexture.width * 4
     }, {
       texture: this.stagingTexture
-    },{
+    }, {
       width: this.stagingTexture.width,
       height: this.stagingTexture.height
     });
@@ -314,7 +324,7 @@ class CanvasRenderManager implements Disposable {
     this.device.queue.submit([commandEncoder.finish()]);
   }
 
-  dispose() : void {
+  dispose(): void {
     this.stagingTexture.destroy();
   }
 }
@@ -453,7 +463,7 @@ export class WebGPUContext {
    * @param code The shader data(in WGSL)
    * @returns The shader
    */
-  createShader(finfo: FunctionInfo, code: string) : Function {
+  createShader(finfo: FunctionInfo, code: string): Function {
     return this.createShadeInternal(finfo, code, false) as Function;
   }
 
@@ -465,7 +475,7 @@ export class WebGPUContext {
    * @param code The shader data(in WGSL)
    * @returns The shader
    */
-  async createShaderAsync(finfo: FunctionInfo, code: string) : 
Promise<Function> {
+  async createShaderAsync(finfo: FunctionInfo, code: string): 
Promise<Function> {
     return await (this.createShadeInternal(finfo, code, true) as 
Promise<Function>);
   }
 
@@ -474,8 +484,8 @@ export class WebGPUContext {
    * \param nbytes The minimum size.
    * \return The allocated buffer
    */
-  private getPodArgsBuffer(nbytes: number) : GPUBuffer {
-    let buffer : GPUBuffer | undefined = undefined;
+  private getPodArgsBuffer(nbytes: number): GPUBuffer {
+    let buffer: GPUBuffer | undefined = undefined;
     if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) {
       buffer = this.podArgStagingBuffers.shift();
     }
@@ -538,8 +548,8 @@ export class WebGPUContext {
 
 
     const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
-    const bufferArgIndices : Array<number> = [];
-    const podArgIndices : Array<number> = [];
+    const bufferArgIndices: Array<number> = [];
+    const podArgIndices: Array<number> = [];
 
     for (let i = 0; i < finfo.arg_types.length; ++i) {
       const dtype = finfo.arg_types[i];
@@ -547,7 +557,7 @@ export class WebGPUContext {
         layoutEntries.push({
           binding: bufferArgIndices.length,
           visibility: GPUShaderStage.COMPUTE,
-          buffer :  {
+          buffer: {
             type: paramWriteAccess[bufferArgIndices.length] ? "storage" : 
"read-only-storage"
           }
         });
@@ -564,7 +574,7 @@ export class WebGPUContext {
     layoutEntries.push({
       binding: bufferArgIndices.length,
       visibility: GPUShaderStage.COMPUTE,
-      buffer :  {
+      buffer: {
         type: "uniform"
       }
     });
@@ -573,14 +583,14 @@ export class WebGPUContext {
       entries: layoutEntries
     });
     const pipelineLayout = this.device.createPipelineLayout({
-      bindGroupLayouts: [ bindGroupLayout ]
+      bindGroupLayouts: [bindGroupLayout]
     });
 
     // Function to create the pipeline.
-    const createShaderFunc =  (pipeline: GPUComputePipeline): Function => {
+    const createShaderFunc = (pipeline: GPUComputePipeline): Function => {
       const submitShader = (...args: Array<GPUPointer | number>): void => {
         if (this.debugShaderSubmitLimit != -1 &&
-            this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
+          this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
           this.shaderSubmitCounter += 1;
           return;
         }
@@ -675,8 +685,8 @@ export class WebGPUContext {
 
         if (this.debugLogFinish) {
           const currCounter = this.shaderSubmitCounter;
-          this.device.queue.onSubmittedWorkDone().then(()=> {
-            console.log("["+ currCounter + "][Debug] finish shader" + 
finfo.name);
+          this.device.queue.onSubmittedWorkDone().then(() => {
+            console.log("[" + currCounter + "][Debug] finish shader" + 
finfo.name);
           });
         }
         this.shaderSubmitCounter += 1;
@@ -799,7 +809,14 @@ export class WebGPUContext {
     nbytes: number
   ): void {
     // Perhaps it would be more useful to use a staging buffer?
-    const rawBytes = this.memory.loadRawBytes(from, nbytes);
+    let rawBytes = this.memory.loadRawBytes(from, nbytes);
+    if (rawBytes.length % 4 !== 0) {  // writeBuffer requires a multiple of 4 bytes
+      const toPad = 4 - rawBytes.length % 4;
+      const padded = new Uint8Array(rawBytes.length + toPad);  // zero-initialized
+      padded.set(rawBytes);  // copy first; reassigning rawBytes before set() loses the data
+      rawBytes = padded;
+      nbytes = nbytes + toPad;
+    }
     this.device.queue.writeBuffer(
       this.gpuBufferFromPtr(to),
       toOffset,

Reply via email to