This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 419a8c861e Batched GPU dispatch and object caching for WebGPU runtime
(#18871)
419a8c861e is described below
commit 419a8c861e6766dbbde5cd184fb77651f2b51587
Author: Miti <[email protected]>
AuthorDate: Fri Mar 6 01:13:09 2026 +0100
Batched GPU dispatch and object caching for WebGPU runtime (#18871)
## Summary
- Batch compute dispatches into a single GPUCommandEncoder, flushing on
sync/readback instead of per-dispatch submit to reduce JS↔GPU transition
overhead during LLM decode
- Cache uniform buffers (FIFO/512), bind groups (FIFO/256), shape
tuples, and pool MAP_READ staging buffers to eliminate redundant GPU
object creation
- Fix padding self-assignment bug in `deviceCopyToGPU`
---
web/src/cache_state.ts | 175 +++++++++++++++++++++++++++++++++++++++
web/src/index.ts | 1 +
web/src/runtime.ts | 14 +++-
web/src/webgpu.ts | 217 +++++++++++++++++++++++++++++++++++++------------
4 files changed, 351 insertions(+), 56 deletions(-)
diff --git a/web/src/cache_state.ts b/web/src/cache_state.ts
new file mode 100644
index 0000000000..c571a071b8
--- /dev/null
+++ b/web/src/cache_state.ts
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Caching utilities for the TVM web runtime.
+ *
+ * Provides a generic LRUCache and a CacheState container that manages
+ * domain-specific caches used by the WebGPU runtime.
+ */
+import { Disposable } from "./types";
+
/**
 * A generic LRU (least recently used) cache with a bounded number of entries.
 *
 * Recency is tracked through the insertion order of the backing `Map`: every
 * hit re-inserts its entry at the end, so the first key in iteration order is
 * always the least recently used one and can be evicted in O(1).
 *
 * @typeParam K - Cache key type.
 * @typeParam V - Cache value type.
 */
export class LRUCache<K, V> {
  private readonly cache: Map<K, V> = new Map();
  private readonly maxSize: number;
  /** Optional callback invoked when an entry is evicted. */
  private readonly onEvict?: (key: K, value: V) => void;

  /**
   * @param maxSize Number of entries at which inserting a new key first
   *   evicts the least recently used entry.
   * @param onEvict Optional callback invoked with each evicted entry, e.g. to
   *   release resources owned by the value.
   */
  constructor(maxSize: number, onEvict?: (key: K, value: V) => void) {
    this.maxSize = maxSize;
    this.onEvict = onEvict;
  }

  /**
   * Get a value from the cache, constructing it via `constructor` on miss.
   *
   * On hit, the entry moves to the most-recently-used position and is
   * returned; `constructor` is not called. Hits are detected with `Map.has`,
   * so a cached `undefined` value still counts as a hit.
   *
   * On miss, the sequence is:
   * 1. Call `constructor()` first.
   * 2. If the cache is full, remove the least recently used entry.
   * 3. Insert and return the new value.
   *
   * Because the value is built before anything is evicted, a throwing
   * `constructor` leaves the cache unchanged. `onEvict` runs only after the
   * cache has been updated, so the cache stays consistent even if the
   * callback throws.
   *
   * @param key The cache key.
   * @param constructor Factory function called on cache miss to produce the
   *   value.
   * @returns The cached or newly constructed value.
   */
  get(key: K, constructor: () => V): V {
    if (this.cache.has(key)) {
      const existing = this.cache.get(key) as V;
      // Re-insert to mark the entry as most recently used.
      this.cache.delete(key);
      this.cache.set(key, existing);
      return existing;
    }
    const value = constructor();
    const evicted =
      this.cache.size >= this.maxSize ? this.removeOldest() : undefined;
    this.cache.set(key, value);
    if (evicted !== undefined && this.onEvict) {
      this.onEvict(evicted[0], evicted[1]);
    }
    return value;
  }

  /**
   * Check whether eviction would be needed for a new entry.
   *
   * Useful when the caller needs to perform side effects before eviction
   * (e.g. flushing pending GPU commands before destroying an evicted buffer).
   *
   * @param key The key to check.
   * @returns true if inserting `key` would trigger eviction of another entry.
   */
  needEviction(key: K): boolean {
    if (this.cache.has(key)) return false;
    return this.cache.size >= this.maxSize;
  }

  /**
   * Clear all cached entries.
   *
   * Does not dispose values and does not call `onEvict` — the caller is
   * responsible for cleanup (e.g. destroying GPU buffers) before calling
   * invalidate.
   */
  invalidate(): void {
    this.cache.clear();
  }

  /** Number of entries currently in the cache. */
  get size(): number {
    return this.cache.size;
  }

  /**
   * Iterate over all cached values, least recently used first (for disposal).
   */
  values(): IterableIterator<V> {
    return this.cache.values();
  }

  /**
   * Remove the least recently used entry, if any.
   *
   * Checks the iterator's `done` flag instead of comparing the key against
   * `undefined`, so an `undefined` key can still be evicted.
   *
   * @returns The removed `[key, value]` pair, or undefined if the cache is
   *   empty.
   */
  private removeOldest(): [K, V] | undefined {
    const first = this.cache.entries().next();
    if (first.done) {
      return undefined;
    }
    const [oldestKey, oldestValue] = first.value;
    this.cache.delete(oldestKey);
    return [oldestKey, oldestValue];
  }
}
+
/**
 * CacheState owns the domain-specific caches used by the WebGPU runtime.
 *
 * Currently contains:
 * - **shapeCache**: Caches TVM ShapeTuple objects keyed by dimension string
 *   (see `computeShapeKey`).
 *   - Why: `makeShapeTuple()` is called on every tensor operation, crossing
 *     the JS→WASM FFI boundary each time. During LLM decode the same shapes
 *     repeat every token (e.g. [1,32,128]), so caching avoids many redundant
 *     FFI round-trips.
 *   - Invalidation: not needed for correctness — shape tuples are immutable
 *     values, so a cached entry never goes stale. Entries ARE disposed,
 *     however, when they are evicted from the bounded LRU cache and when
 *     `dispose()` is called.
 *     NOTE(review): holders of a cached tuple presumably must not dispose it
 *     themselves nor rely on it after it has been evicted — confirm against
 *     the callers of this cache.
 *
 * Future additions (follow-up PR):
 * - **uniformCache**: Caches GPU uniform buffers keyed by content hash.
 *   - Why: Many dispatches use identical scalar arguments (matrix dims, etc.).
 *     Reusing the buffer avoids `createBuffer` + `writeBuffer` overhead.
 *   - Invalidation: Must invalidate on any GPU buffer deallocation, since
 *     buffer pointers can be reused by the allocator, making cached entries
 *     that reference the old buffer stale.
 */
export class CacheState {
  /**
   * Cache for TVM ShapeTuple objects.
   *
   * Key: comma-separated dimension string, e.g. "1,32,128".
   * Value: TVM ShapeTuple object (Disposable); disposed when evicted.
   */
  readonly shapeCache: LRUCache<string, Disposable>;

  /**
   * @param shapeCacheSize Maximum number of shape tuples kept in the cache
   *   before the least recently used one is evicted and disposed.
   */
  constructor(shapeCacheSize: number = 256) {
    this.shapeCache = new LRUCache<string, Disposable>(
      shapeCacheSize,
      (_key, value) => value.dispose()
    );
  }

  /**
   * Compute the cache key for a shape tuple.
   *
   * @param shape Array of dimension values.
   * @returns String key suitable for shapeCache lookup, e.g. "1,32,128".
   */
  static computeShapeKey(shape: Array<number>): string {
    return shape.toString();
  }

  /**
   * Dispose all cached objects and clear all caches.
   *
   * Every cached object gets a dispose() attempt even if an earlier one
   * throws; the first error is rethrown after all attempts.
   */
  dispose(): void {
    // Snapshot and empty the cache before disposing anything, so that an
    // exception below cannot leave already-disposed objects in the cache
    // (where a later dispose() would release them a second time).
    const cached = Array.from(this.shapeCache.values());
    this.shapeCache.invalidate();
    let failed = false;
    let firstError: unknown = undefined;
    for (const obj of cached) {
      try {
        obj.dispose();
      } catch (err) {
        // Keep going so one failing object does not leak the others.
        if (!failed) {
          failed = true;
          firstError = err;
        }
      }
    }
    if (failed) {
      throw firstError;
    }
  }
}
diff --git a/web/src/index.ts b/web/src/index.ts
index 868a26623a..e26c84cc38 100644
--- a/web/src/index.ts
+++ b/web/src/index.ts
@@ -35,4 +35,5 @@ export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
export { assert, wasmPath, LinearCongruentialGenerator } from "./support";
export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
+export { LRUCache, CacheState } from "./cache_state";
export { createPolyfillWASI } from "./compact";
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index d3f7216aae..076efdb5fd 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -27,6 +27,7 @@ import { assert, StringToUint8Array,
LinearCongruentialGenerator } from "./suppo
import { Environment } from "./environment";
import { AsyncifyHandler } from "./asyncify";
import { FunctionInfo, WebGPUContext } from "./webgpu";
+import { CacheState } from "./cache_state";
import {
ArtifactCache,
ArtifactCacheTemplate,
@@ -859,6 +860,7 @@ export class Instance implements Disposable {
private initProgressCallback: Array<InitProgressCallback> = [];
private rng: LinearCongruentialGenerator;
private deviceLostIsError = true; // whether device.lost is due to actual
error or dispose()
+ private cacheState: CacheState = new CacheState();
/**
* Internal function(registered by the runtime)
@@ -954,6 +956,8 @@ export class Instance implements Disposable {
dispose(): void {
this.deviceLostIsError = false; // prevent dispose to trigger device.lost
error
// order matters
+ // dispose caches before ctx
+ this.cacheState.dispose();
// ctx release goes back into lib.
this.ctx.dispose();
this.lib.dispose();
@@ -1674,8 +1678,14 @@ export class Instance implements Disposable {
* @returns The created shape tuple.
*/
makeShapeTuple(shape: Array<number>): TVMObject {
- const shapeArray = shape.map((value) => new Scalar(value, "int"));
- return this.ctx.makeShapeTuple(...shapeArray);
+ // Equal dimension lists map to the same cache key, so repeated calls
+ // return one shared ShapeTuple instead of crossing the FFI each time.
+ // NOTE(review): the shape cache disposes tuples it evicts (LRU, 256
+ // entries by default) and when the Instance is disposed; callers should
+ // presumably not dispose the returned tuple themselves — confirm.
+ const key = CacheState.computeShapeKey(shape);
+ return this.cacheState.shapeCache.get(key, () => {
+ const shapeArray = shape.map((value) => new Scalar(value, "int"));
+ const tuple = this.ctx.makeShapeTuple(...shapeArray);
+ // Detach from scope so the cached object survives across scopes.
+ this.detachFromCurrentScope(tuple);
+ return tuple;
+ }) as TVMObject;
}
/**
* Get type index from type key.
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index fe5d7f5553..55d188516d 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -127,7 +127,6 @@ export async function detectGPUDevice(powerPreference:
"low-power" | "high-perfo
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
-
// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers,
hence `as any`.
const adapterInfo = adapter.info || await (adapter as
any).requestAdapterInfo();
@@ -402,10 +401,25 @@ export class WebGPUContext {
// internal data
private bufferTable: Array<GPUBuffer | undefined> = [undefined];
private bufferTableFreeId: Array<number> = [];
- private podArgStagingBuffers: Array<GPUBuffer> = [];
private canvasRenderManager?: CanvasRenderManager = undefined;
- // number of pod arg staging buffers
- private maxNumPodArgsStagingBuffers = 2;
+ // Pool of MAP_READ staging buffers to avoid per-copy create/destroy overhead
+ private readStagingBufferPool: Array<{ buffer: GPUBuffer; size: number }> =
[];
+ private maxReadStagingBuffers = 4;
+ // Pending mapAsync promise from the last GPU→CPU copy.
+ // Used in sync() as a fast path: if the last queue operation was a
+ // GPU→CPU copy, awaiting its mapAsync is sufficient (no need for
+ // the heavier onSubmittedWorkDone). Reset to null after any non-copy
+ // queue submission so we fall back to onSubmittedWorkDone.
+ private pendingGPUToCPUCopy: Promise<void> | null = null;
+ // Batched command encoding: accumulate compute passes in a single encoder,
+ // submit only on flush to reduce JS-native transition overhead.
+ private pendingEncoder: GPUCommandEncoder | null = null;
+ // Pool of uniform buffers reused across flushes. Each dispatch in a batch
+ // gets its own buffer (indexed by pendingDispatchCount). The pool grows
+ // as needed but buffers are never destroyed — just reused next batch.
+ private uniformBufferPool: Array<GPUBuffer> = [];
+ private uniformBufferPoolSizes: Array<number> = [];
+ private pendingDispatchCount = 0;
// flags for debugging
// stats of the runtime.
// peak allocation
@@ -426,26 +440,63 @@ export class WebGPUContext {
this.device = device;
}
+ /**
+ * Flush all pending compute passes by finishing and submitting the
+ * accumulated command encoder.
+ *
+ * Must be called before:
+ * - GPU→CPU readback (deviceCopyFromGPU)
+ * - CPU→GPU writes (deviceCopyToGPU, copyRawBytesToBuffer)
+ * - GPU↔GPU copies (deviceCopyWithinGPU)
+ * - Buffer deallocation (deviceFreeDataSpace)
+ * - Queue sync (sync)
+ */
+ flushCommands(): void {
+ if (this.pendingEncoder) {
+ this.device.queue.submit([this.pendingEncoder.finish()]);
+ this.pendingEncoder = null;
+ this.pendingDispatchCount = 0;
+ // A compute submission is now the last queue operation, so the
+ // GPU→CPU copy fast path in sync() is no longer valid.
+ this.pendingGPUToCPUCopy = null;
+ }
+ }
+
/**
* Dispose context.
+ *
+ * Submits any batched compute passes first, then destroys every buffer the
+ * context owns (buffer table, uniform buffer pool, readback staging pool)
+ * and finally the device itself.
*/
dispose() {
+ // Submit batched passes before the buffers they reference are destroyed.
+ this.flushCommands();
this.canvasRenderManager?.dispose();
this.bufferTableFreeId = [];
while (this.bufferTable.length != 0) {
this.bufferTable.pop()?.destroy();
}
- while (this.podArgStagingBuffers.length != 0) {
- this.podArgStagingBuffers.pop()?.destroy();
+ // Uniform buffers are kept for reuse across batches and only freed here
+ // (or when replaced by a larger one).
+ for (const buf of this.uniformBufferPool) {
+ buf.destroy();
+ }
+ this.uniformBufferPool.length = 0;
+ this.uniformBufferPoolSizes.length = 0;
+ while (this.readStagingBufferPool.length != 0) {
+ this.readStagingBufferPool.pop()?.buffer.destroy();
}
this.device.destroy();
}
+
/**
* Wait for all pending GPU tasks to complete
+ *
+ * Batched compute passes are submitted first. If the last recorded queue
+ * work is a GPU→CPU readback, awaiting its mapAsync promise (which is
+ * chained after any earlier pending readbacks) is treated as sufficient;
+ * otherwise this waits on queue.onSubmittedWorkDone().
*/
async sync(): Promise<void> {
- await this.device.queue.onSubmittedWorkDone();
+ // Flush any batched compute passes before waiting on the queue.
+ this.flushCommands();
+ // NOTE(review): this fast path assumes nothing was submitted after the
+ // readback. pendingGPUToCPUCopy is cleared only here and by
+ // flushCommands() when it actually submits, so confirm that writeBuffer
+ // or copy-only submissions issued after a readback are also covered.
+ if (this.pendingGPUToCPUCopy) {
+ const p = this.pendingGPUToCPUCopy;
+ this.pendingGPUToCPUCopy = null;
+ await p;
+ } else {
+ await this.device.queue.onSubmittedWorkDone();
+ }
}
/**
@@ -485,7 +536,8 @@ export class WebGPUContext {
toOffset: number,
nbytes: number
): void {
- // Perhaps it would be more useful to use a staging buffer?
+ // Flush batched compute passes before writing, to preserve execution
order.
+ this.flushCommands();
this.device.queue.writeBuffer(
this.gpuBufferFromPtr(toPtr),
toOffset,
@@ -534,36 +586,41 @@ export class WebGPUContext {
}
/**
- * Get the pod arg staging buffer
- * \param nbytes The minimum size.
- * \return The allocated buffer
+ * Get a uniform buffer from the per-dispatch pool.
+ *
+ * Each dispatch in a batched encoder needs its own uniform buffer because
+ * queue.writeBuffer() is enqueued immediately, while the batched compute
+ * passes are only submitted at flush time. Sharing one buffer would let a
+ * later write overwrite data before earlier dispatches in the batch read it.
+ *
+ * Slot i of the pool holds the buffer for the i-th dispatch of the current
+ * batch; flushCommands() resets pendingDispatchCount to 0, so the next batch
+ * reuses the slots from the start. This method never flushes: when slot i
+ * is missing or smaller than nbytes, the undersized buffer (if any) is
+ * destroyed and replaced by a new buffer of exactly nbytes. The pool
+ * therefore grows to the largest number of dispatches seen in one batch,
+ * and its buffers are otherwise released only in dispose().
+ *
+ * @param nbytes Minimum buffer size in bytes.
+ * @returns A GPUBuffer with UNIFORM | COPY_DST usage, at least nbytes large.
*/
- private getPodArgsBuffer(nbytes: number): GPUBuffer {
- let buffer: GPUBuffer | undefined = undefined;
- if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) {
- buffer = this.podArgStagingBuffers.shift();
- }
- // minimum of 16 bytes
- let allocSize = 16;
- if (buffer !== undefined) {
- allocSize = buffer.size;
- if (buffer.size < nbytes) {
- buffer.destroy();
- buffer = undefined;
- }
+ private getUniformFromPool(nbytes: number): GPUBuffer {
+ // Claim the next per-dispatch slot; flushCommands() resets the counter.
+ const dispatchIdx = this.pendingDispatchCount++;
+ if (dispatchIdx < this.uniformBufferPool.length &&
+ this.uniformBufferPoolSizes[dispatchIdx] >= nbytes) {
+ return this.uniformBufferPool[dispatchIdx];
}
- while (allocSize < nbytes) {
- allocSize *= 2;
+ // Destroy old undersized buffer if it exists.
+ if (dispatchIdx < this.uniformBufferPool.length) {
+ this.uniformBufferPool[dispatchIdx].destroy();
}
-
- if (buffer == undefined) {
- // create uniform buffer
- buffer = tryCreateBuffer(this.device, {
- size: allocSize,
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
- });
- }
- assert(nbytes <= buffer.size);
+ // NOTE(review): unlike the readback staging path (tryCreateBuffer), this
+ // allocates with device.createBuffer at exactly nbytes, with no minimum
+ // size; confirm that is intended.
+ const buffer = this.device.createBuffer({
+ size: nbytes,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+ });
+ this.uniformBufferPool[dispatchIdx] = buffer;
+ this.uniformBufferPoolSizes[dispatchIdx] = nbytes;
return buffer;
}
@@ -649,8 +706,12 @@ export class WebGPUContext {
return;
}
- const commandEncoder = this.device.createCommandEncoder();
- const compute = commandEncoder.beginComputePass();
+ // Reuse a single command encoder across dispatches; only flush on
sync/readback.
+ if (!this.pendingEncoder) {
+ this.pendingEncoder = this.device.createCommandEncoder();
+ }
+
+ const compute = this.pendingEncoder.beginComputePass();
compute.setPipeline(pipeline);
const bindGroupEntries: Array<GPUBindGroupEntry> = [];
const numBufferOrPodArgs = bufferArgIndices.length +
podArgIndices.length;
@@ -695,9 +756,9 @@ export class WebGPUContext {
});
}
- // push pod buffer
const sizeOfI32 = 4;
- const podArgBuffer = this.getPodArgsBuffer((podArgIndices.length + 1)
* sizeOfI32);
+ const bufBytes = (podArgIndices.length + 1) * sizeOfI32;
+ const podArgBuffer = this.getUniformFromPool(bufBytes);
const i32View = new Int32Array(podArgIndices.length + 1);
const u32View = new Uint32Array(i32View.buffer);
const f32View = new Float32Array(i32View.buffer);
@@ -732,12 +793,12 @@ export class WebGPUContext {
entries: bindGroupEntries
}));
- compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2])
- compute.end()
- const command = commandEncoder.finish();
- this.device.queue.submit([command]);
+ compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2]);
+ compute.end();
+ // In debug mode, flush immediately so we can observe each submission.
if (this.debugLogFinish) {
+ this.flushCommands();
const currCounter = this.shaderSubmitCounter;
this.device.queue.onSubmittedWorkDone().then(() => {
console.log("[" + currCounter + "][Debug] finish shader" +
finfo.name);
@@ -853,6 +914,10 @@ export class WebGPUContext {
assert(buffer !== undefined);
this.bufferTableFreeId.push(idx);
this.currAllocatedBytes -= buffer.size;
+ // Flush any pending compute passes that may reference this buffer
+ // before destroying it, otherwise queue.submit() will fail with
+ // "buffer used in submit while destroyed".
+ this.flushCommands();
buffer.destroy();
}
@@ -862,13 +927,17 @@ export class WebGPUContext {
toOffset: number,
nbytes: number
): void {
- // Perhaps it would be more useful to use a staging buffer?
+ // Flush batched compute passes before writing to a GPU buffer,
+ // otherwise the write may be reordered before pending dispatches
+ // that read from the same buffer.
+ this.flushCommands();
let rawBytes = this.memory.loadRawBytes(from, nbytes);
if (rawBytes.length % 4 !== 0) {
// writeBuffer requires length to be multiples of 4, so we pad here
const toPad = 4 - rawBytes.length % 4;
- rawBytes = new Uint8Array(rawBytes.length + toPad);
- rawBytes.set(rawBytes);
+ const padded = new Uint8Array(rawBytes.length + toPad);
+ padded.set(rawBytes);
+ rawBytes = padded;
nbytes = nbytes + toPad;
}
this.device.queue.writeBuffer(
@@ -880,17 +949,51 @@ export class WebGPUContext {
);
}
+ /**
+ * Get a MAP_READ staging buffer from the pool, or create one if none fits.
+ * Uses first-fit-by-size: returns the first pooled buffer >= nbytes.
+ */
+ private getOrCreateReadStagingBuffer(nbytes: number): GPUBuffer {
+ for (let i = 0; i < this.readStagingBufferPool.length; i++) {
+ if (this.readStagingBufferPool[i].size >= nbytes) {
+ const entry = this.readStagingBufferPool.splice(i, 1)[0];
+ return entry.buffer;
+ }
+ }
+ return tryCreateBuffer(this.device, {
+ size: nbytes,
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+ });
+ }
+
+ /**
+ * Return a MAP_READ staging buffer to the pool for reuse.
+ * Evicts the smallest buffer if the pool is full.
+ */
+ private recycleReadStagingBuffer(buf: GPUBuffer): void {
+ buf.unmap();
+ if (this.readStagingBufferPool.length >= this.maxReadStagingBuffers) {
+ // Evict smallest buffer to make room
+ let minIdx = 0;
+ for (let i = 1; i < this.readStagingBufferPool.length; i++) {
+ if (this.readStagingBufferPool[i].size <
this.readStagingBufferPool[minIdx].size) {
+ minIdx = i;
+ }
+ }
+ this.readStagingBufferPool.splice(minIdx, 1)[0].buffer.destroy();
+ }
+ this.readStagingBufferPool.push({ buffer: buf, size: buf.size });
+ }
+
private deviceCopyFromGPU(
from: GPUPointer,
fromOffset: number,
to: Pointer,
nbytes: number
): void {
- // Perhaps it would be more useful to resuse a staging buffer?
- const gpuTemp = tryCreateBuffer(this.device, {
- size: nbytes,
- usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
- });
+ // Flush batched compute passes before the readback copy.
+ this.flushCommands();
+ const gpuTemp = this.getOrCreateReadStagingBuffer(nbytes);
const copyEncoder = this.device.createCommandEncoder();
copyEncoder.copyBufferToBuffer(
@@ -903,11 +1006,15 @@ export class WebGPUContext {
const copyCommands = copyEncoder.finish();
this.device.queue.submit([copyCommands]);
- gpuTemp.mapAsync(GPUMapMode.READ).then(() => {
- const data = gpuTemp.getMappedRange();
+ const readPromise = gpuTemp.mapAsync(GPUMapMode.READ).then(() => {
+ const data = gpuTemp.getMappedRange(0, nbytes);
this.memory.storeRawBytes(to, new Uint8Array(data));
- gpuTemp.destroy();
+ this.recycleReadStagingBuffer(gpuTemp);
});
+ // Chain with any existing pending read so sync() awaits all of them.
+ this.pendingGPUToCPUCopy = this.pendingGPUToCPUCopy
+ ? this.pendingGPUToCPUCopy.then(() => readPromise)
+ : readPromise;
}
private deviceCopyWithinGPU(
@@ -917,6 +1024,8 @@ export class WebGPUContext {
toOffset: number,
nbytes: number
): void {
+ // Flush batched compute passes before the GPU-to-GPU copy.
+ this.flushCommands();
const copyEncoder = this.device.createCommandEncoder();
copyEncoder.copyBufferToBuffer(
this.gpuBufferFromPtr(from),