This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 939b8b9ce7 [Web] Separate parallel shard download and iterative shard
loading (#16650)
939b8b9ce7 is described below
commit 939b8b9ce7e7f2b6289e883a7040b19cddb28636
Author: Hangrui Cao <[email protected]>
AuthorDate: Thu Mar 14 12:13:21 2024 -0400
[Web] Separate parallel shard download and iterative shard loading (#16650)
* Fix Parallel Download Issue by separating the downloading from the
serialization process
Co-authored-by: Charlie Ruan
<[email protected]>
* Fix callback display
* [Web] Support IndexedDB Caching
* Limit max concurrent download to 4 shards
* Try to catch error when loading model to ndarray cache
---------
Co-authored-by: Charlie Ruan
<[email protected]>
---
web/src/artifact_cache.ts | 5 ++
web/src/runtime.ts | 133 +++++++++++++++++++++++++++++++++-------------
2 files changed, 101 insertions(+), 37 deletions(-)
diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts
index ffb5011324..da9aaddfb0 100644
--- a/web/src/artifact_cache.ts
+++ b/web/src/artifact_cache.ts
@@ -25,6 +25,11 @@ export interface ArtifactCacheTemplate {
*/
fetchWithCache(url: string);
+ /**
+ * add key url to cache
+ */
+ addToCache(url: string);
+
/**
* check if cache has all keys in Cache
*/
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 8df48c43a5..ea022d1b3e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -165,6 +165,7 @@ class RuntimeContext implements Disposable {
makeShapeTuple: PackedFunc;
ndarrayCreateView: PackedFunc;
sampleTopPFromLogits: PackedFunc;
+ sampleTopPFromProb: PackedFunc;
applyRepetitionPenalty: PackedFunc;
applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
@@ -188,6 +189,7 @@ class RuntimeContext implements Disposable {
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits =
getGlobalFunc("vm.builtin.sample_top_p_from_logits");
+ this.sampleTopPFromProb =
getGlobalFunc("vm.builtin.sample_top_p_from_prob");
this.applyRepetitionPenalty =
getGlobalFunc("vm.builtin.apply_repetition_penalty");
this.applyPresenceAndFrequencyPenalty =
getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature =
getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
@@ -1020,6 +1022,17 @@ export class ArtifactCache implements
ArtifactCacheTemplate {
return result;
}
+ async addToCache(url: string) {
+ const request = new Request(url);
+ if (this.cache === undefined) {
+ this.cache = await caches.open(this.scope);
+ }
+ const result = await this.cache.match(request);
+ if (result === undefined) {
+ await this.cache.add(request);
+ }
+ }
+
async hasAllKeys(keys: string[]) {
if (this.cache === undefined) {
this.cache = await caches.open(this.scope);
@@ -1534,20 +1547,24 @@ export class Instance implements Disposable {
const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new
URL(key.dataPath, ndarrayCacheUrl).href))
- const reportCallback = (iter: number) => {
+ const reportCallback = (iter: number, loading = false) => {
// report
for (let j = 0; j < this.initProgressCallback.length; ++j) {
- let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
- text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
fetched. "
- text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
- text += timeElapsed + " secs elapsed.";
- text += " It can take a while when we first visit this page to
populate the cache."
- text += " Later refreshes will become faster.";
- if (cacheOnly) {
+ let text: string;
+ if (loading) {
+ text = "Finished fetching params, loading onto WebGPU.";
+ } else if (cacheOnly) {
text = "Loading model from cache[" + iter + "/" + list.length + "]:
";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
loaded. "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
text += timeElapsed + " secs elapsed.";
+ } else {
+ text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
+ text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
fetched. "
+ text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
+ text += timeElapsed + " secs elapsed.";
+ text += " It can take a while when we first visit this page to
populate the cache."
+ text += " Later refreshes will become faster.";
}
this.initProgressCallback[j]({
progress: fetchedBytes / totalBytes,
@@ -1567,7 +1584,35 @@ export class Instance implements Disposable {
});
}
- const processShard = async (i: number) => {
+ // First download all shards to cache in parallel if not yet in cache
+ const downloadCache = async (start: number, end: number) => {
+ // Download params [start, end) from `list`
+ for (let i = start; i < end; i++) {
+ const shard = list[i];
+ const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
+ try {
+ await artifactCache.addToCache(dataUrl);
+ } catch (err) {
+ this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
+ throw err;
+ }
+ timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
+ fetchedBytes += shard.nbytes;
+ reportCallback(fetchedShards++);
+ }
+ }
+ // We launch 4 parallel for loops to limit the max concurrency to 4
download
+ const loopSize = Math.floor(list.length / 4);
+ await Promise.all([
+ downloadCache(0, loopSize),
+ downloadCache(loopSize, 2 * loopSize),
+ downloadCache(2 * loopSize, 3 * loopSize),
+ downloadCache(3 * loopSize, list.length)
+ ]);
+ reportCallback(list.length, /*loading=*/true);
+
+ // Then iteratively, load the shard from cache
+ for (let i = 0; i < list.length; ++i) {
const shard = list[i];
const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
let buffer;
@@ -1579,39 +1624,42 @@ export class Instance implements Disposable {
}
const shardRecords = shard.records;
for (let j = 0; j < shardRecords.length; ++j) {
- const rec = shardRecords[j];
- const cpu_arr = this.withNewScope(() => {
- return this.detachFromCurrentScope(
- this.empty(rec.shape, rec.dtype, this.cpu())
- )
- });
- const recSource = buffer.slice(rec.byteOffset, rec.byteOffset +
rec.nbytes);
- // first sync copy to cpu.
- this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource),
rec.format, rec.dtype);
- // then async stream into GPU if needed
- if (device.deviceType === DeviceStrToEnum.cpu) {
- this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
- cpu_arr.dispose();
- } else {
- // allocate a gpu arr and async copy to it.
- const gpu_arr = this.withNewScope(() => {
+ try {
+ const rec = shardRecords[j];
+ const cpu_arr = this.withNewScope(() => {
return this.detachFromCurrentScope(
- this.empty(rec.shape, rec.dtype, device)
+ this.empty(rec.shape, rec.dtype, this.cpu())
)
});
- gpu_arr.copyFrom(cpu_arr);
- await device.sync();
- this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
- cpu_arr.dispose();
- gpu_arr.dispose();
+ const recSource = buffer.slice(rec.byteOffset, rec.byteOffset +
rec.nbytes);
+ // first sync copy to cpu.
+ this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource),
rec.format, rec.dtype);
+ // then async stream into GPU if needed
+ if (device.deviceType === DeviceStrToEnum.cpu) {
+ this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
+ cpu_arr.dispose();
+ } else {
+ // allocate a gpu arr and async copy to it.
+ const gpu_arr = this.withNewScope(() => {
+ return this.detachFromCurrentScope(
+ this.empty(rec.shape, rec.dtype, device)
+ )
+ });
+ gpu_arr.copyFrom(cpu_arr);
+ await device.sync();
+ this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
+ cpu_arr.dispose();
+ gpu_arr.dispose();
+ }
+ } catch (err) {
+ this.env.logger(
+ "Failed to load shard " + i + "'s record: " +
JSON.stringify(shardRecords[j]) + "\n" +
+ "Error: " + err
+ );
+ throw err;
}
}
- timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
- fetchedBytes += shard.nbytes;
- reportCallback(fetchedShards++);
}
- await Promise.all(list.map((_, index) => processShard(index)));
- reportCallback(list.length);
}
/**
@@ -1780,6 +1828,17 @@ export class Instance implements Disposable {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p,
Math.random());
}
+ /**
+ * Sample index via top-p sampling.
+ *
+ * @param prob The distribution, i.e. logits after
`applySoftmaxWithTemperature()` is performed.
+ * @param top_p The top_p
+ * @returns The sampled index.
+ */
+ sampleTopPFromProb(prob: NDArray, top_p: number): number {
+ return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
+ }
+
/**
* Apply repetition penalty to the logits.
* @param logits The input logits before penalty.
@@ -2549,7 +2608,7 @@ export async function deleteNDArrayCache(
const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href;
const result = await artifactCache.fetchWithCache(jsonUrl);
let list;
- if (result instanceof Response){
+ if (result instanceof Response) {
list = await result.json();
}
const arrayentry = list["records"] as Array<NDArrayShardEntry>;