This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 939b8b9ce7 [Web] Separate parallel shard download and iterative shard loading (#16650)
939b8b9ce7 is described below

commit 939b8b9ce7e7f2b6289e883a7040b19cddb28636
Author: Hangrui Cao <[email protected]>
AuthorDate: Thu Mar 14 12:13:21 2024 -0400

    [Web] Separate parallel shard download and iterative shard loading (#16650)
    
    * Fix parallel download issue by separating the download from the serialization process
    
    Co-authored-by: Charlie Ruan <[email protected]>
    
    * Fix callback display
    
    * [Web] Support IndexedDB caching
    
    * Limit max concurrent downloads to 4 shards
    
    * Try to catch error when loading model to ndarray cache
    
    
    ---------
    
    Co-authored-by: Charlie Ruan <[email protected]>
---
 web/src/artifact_cache.ts |   5 ++
 web/src/runtime.ts        | 133 +++++++++++++++++++++++++++++++++-------------
 2 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts
index ffb5011324..da9aaddfb0 100644
--- a/web/src/artifact_cache.ts
+++ b/web/src/artifact_cache.ts
@@ -25,6 +25,11 @@ export interface ArtifactCacheTemplate {
    */
   fetchWithCache(url: string);
 
+  /**
+   * Add the resource at `url` to the cache.
+   */
+  addToCache(url: string);
+
   /**
    * check if cache has all keys in Cache
    */
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 8df48c43a5..ea022d1b3e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -165,6 +165,7 @@ class RuntimeContext implements Disposable {
   makeShapeTuple: PackedFunc;
   ndarrayCreateView: PackedFunc;
   sampleTopPFromLogits: PackedFunc;
+  sampleTopPFromProb: PackedFunc;
   applyRepetitionPenalty: PackedFunc;
   applyPresenceAndFrequencyPenalty: PackedFunc;
   applySoftmaxWithTemperature: PackedFunc;
@@ -188,6 +189,7 @@ class RuntimeContext implements Disposable {
     this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
     this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
     this.sampleTopPFromLogits = 
getGlobalFunc("vm.builtin.sample_top_p_from_logits");
+    this.sampleTopPFromProb = 
getGlobalFunc("vm.builtin.sample_top_p_from_prob");
     this.applyRepetitionPenalty = 
getGlobalFunc("vm.builtin.apply_repetition_penalty");
     this.applyPresenceAndFrequencyPenalty = 
getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
     this.applySoftmaxWithTemperature = 
getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
@@ -1020,6 +1022,17 @@ export class ArtifactCache implements 
ArtifactCacheTemplate {
     return result;
   }
 
+  async addToCache(url: string) {
+    const req = new Request(url);
+    // Open the cache scope lazily on first use.
+    if (this.cache === undefined) this.cache = await caches.open(this.scope);
+    // Cache.add() fetches and stores the response; skip it if already cached.
+    const cached = await this.cache.match(req);
+    if (cached === undefined) {
+      await this.cache.add(req);
+    }
+  }
+
   async hasAllKeys(keys: string[]) {
     if (this.cache === undefined) {
       this.cache = await caches.open(this.scope);
@@ -1534,20 +1547,24 @@ export class Instance implements Disposable {
 
     const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new 
URL(key.dataPath, ndarrayCacheUrl).href))
 
-    const reportCallback = (iter: number) => {
+    const reportCallback = (iter: number, loading = false) => {
       // report
       for (let j = 0; j < this.initProgressCallback.length; ++j) {
-        let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
-        text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB 
fetched. "
-        text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% 
completed, "
-        text += timeElapsed + " secs elapsed.";
-        text += " It can take a while when we first visit this page to 
populate the cache."
-        text += " Later refreshes will become faster.";
-        if (cacheOnly) {
+        let text: string;
+        if (loading) {
+          text = "Finished fetching params, loading onto WebGPU.";
+        } else if (cacheOnly) {
           text = "Loading model from cache[" + iter + "/" + list.length + "]: 
";
           text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB 
loaded. "
           text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% 
completed, "
           text += timeElapsed + " secs elapsed.";
+        } else {
+          text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
+          text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB 
fetched. "
+          text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% 
completed, "
+          text += timeElapsed + " secs elapsed.";
+          text += " It can take a while when we first visit this page to 
populate the cache."
+          text += " Later refreshes will become faster.";
         }
         this.initProgressCallback[j]({
           progress: fetchedBytes / totalBytes,
@@ -1567,7 +1584,35 @@ export class Instance implements Disposable {
       });
     }
 
-    const processShard = async (i: number) => {
+    // First download all shards into the cache in parallel, if not yet cached.
+    const downloadCache = async (start: number, end: number) => {
+      // Download params [start, end) from `list`; rethrows the first fetch error.
+      for (let i = start; i < end; i++) {
+        const shard = list[i];
+        const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
+        try {
+          await artifactCache.addToCache(dataUrl);
+        } catch (err) {
+          this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
+          throw err;
+        }
+        timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
+        fetchedBytes += shard.nbytes;
+        reportCallback(++fetchedShards);  // count includes the shard just fetched
+      }
+    }
+    // We launch 4 parallel for loops to limit the max concurrency to 4 
download
+    const loopSize = Math.floor(list.length / 4);
+    await Promise.all([
+      downloadCache(0, loopSize),
+      downloadCache(loopSize, 2 * loopSize),
+      downloadCache(2 * loopSize, 3 * loopSize),
+      downloadCache(3 * loopSize, list.length)
+    ]);
+    reportCallback(list.length, /*loading=*/true);
+
+    // Then iteratively, load the shard from cache
+    for (let i = 0; i < list.length; ++i) {
       const shard = list[i];
       const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
       let buffer;
@@ -1579,39 +1624,42 @@ export class Instance implements Disposable {
       }
       const shardRecords = shard.records;
       for (let j = 0; j < shardRecords.length; ++j) {
-        const rec = shardRecords[j];
-        const cpu_arr = this.withNewScope(() => {
-          return this.detachFromCurrentScope(
-            this.empty(rec.shape, rec.dtype, this.cpu())
-          )
-        });
-        const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + 
rec.nbytes);
-        // first sync copy to cpu.
-        this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), 
rec.format, rec.dtype);
-        // then async stream into GPU if needed
-        if (device.deviceType === DeviceStrToEnum.cpu) {
-          this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
-          cpu_arr.dispose();
-        } else {
-          // allocate a gpu arr and async copy to it.
-          const gpu_arr = this.withNewScope(() => {
+        try {
+          const rec = shardRecords[j];
+          const cpu_arr = this.withNewScope(() => {
             return this.detachFromCurrentScope(
-              this.empty(rec.shape, rec.dtype, device)
+              this.empty(rec.shape, rec.dtype, this.cpu())
             )
           });
-          gpu_arr.copyFrom(cpu_arr);
-          await device.sync();
-          this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
-          cpu_arr.dispose();
-          gpu_arr.dispose();
+          const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + 
rec.nbytes);
+          // first sync copy to cpu.
+          this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), 
rec.format, rec.dtype);
+          // then async stream into GPU if needed
+          if (device.deviceType === DeviceStrToEnum.cpu) {
+            this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
+            cpu_arr.dispose();
+          } else {
+            // allocate a gpu arr and async copy to it.
+            const gpu_arr = this.withNewScope(() => {
+              return this.detachFromCurrentScope(
+                this.empty(rec.shape, rec.dtype, device)
+              )
+            });
+            gpu_arr.copyFrom(cpu_arr);
+            await device.sync();
+            this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
+            cpu_arr.dispose();
+            gpu_arr.dispose();
+          }
+        } catch (err) {
+          this.env.logger(
+            "Failed to load shard " + i + "'s record: " + 
JSON.stringify(shardRecords[j]) + "\n" +
+            "Error: " + err
+          );
+          throw err;
         }
       }
-      timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
-      fetchedBytes += shard.nbytes;
-      reportCallback(fetchedShards++);
     }
-    await Promise.all(list.map((_, index) => processShard(index)));
-    reportCallback(list.length);
   }
 
   /**
@@ -1780,6 +1828,17 @@ export class Instance implements Disposable {
     return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, 
Math.random());
   }
 
+  /**
+   * Draw an index from the probability distribution `prob` via top-p sampling.
+   * @param prob The distribution, e.g. logits after `applySoftmaxWithTemperature()`.
+   * @param top_p The top-p threshold forwarded to `vm.builtin.sample_top_p_from_prob`.
+   * @returns The sampled index.
+   */
+  sampleTopPFromProb(prob: NDArray, top_p: number): number {
+    const uniformSample = Math.random();
+    return this.ctx.sampleTopPFromProb(prob, top_p, uniformSample);
+  }
+
   /**
    * Apply repetition penalty to the logits.
    * @param logits The input logits before penalty.
@@ -2549,7 +2608,7 @@ export async function deleteNDArrayCache(
   const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href;
   const result = await artifactCache.fetchWithCache(jsonUrl);
   let list;
-  if (result instanceof Response){
+  if (result instanceof Response) {
     list = await result.json();
   }
   const arrayentry = list["records"] as Array<NDArrayShardEntry>;

Reply via email to