This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 678d01dd4a [Unity][WEB] Relax vm on web runtime (#14131)
678d01dd4a is described below

commit 678d01dd4a4e75ef6186ce356bb1a20e584a7b24
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 25 13:22:10 2023 -0500

    [Unity][WEB] Relax vm on web runtime (#14131)
    
    This PR brings initial Relax VM support to the web runtime.
---
 include/tvm/runtime/relax_vm/vm.h                  |   4 +
 python/tvm/contrib/tvmjs.py                        | 119 ++++++++
 python/tvm/exec/rpc_proxy.py                       |  32 ++-
 python/tvm/relax/vm_build.py                       |  14 +-
 python/tvm/rpc/proxy.py                            |  21 +-
 src/runtime/relax_vm/vm.cc                         |  11 +
 web/.gitignore                                     |   1 +
 web/apps/browser/rpc_server.html                   |  65 ++++-
 web/emcc/wasm_runtime.cc                           |  74 +++++
 web/src/rpc_server.ts                              |  29 +-
 web/src/runtime.ts                                 | 315 +++++++++++++++++++--
 web/tests/node/test_relax_vm.js                    |  67 +++++
 web/tests/python/prepare_test_libs.py              |  30 +-
 .../{webgpu_rpc_test.py => relax_rpc_test.py}      |  79 +++---
 web/tests/python/webgpu_rpc_test.py                |   4 +-
 web/tests/python/websock_rpc_test.py               |   4 +-
 16 files changed, 780 insertions(+), 89 deletions(-)

diff --git a/include/tvm/runtime/relax_vm/vm.h 
b/include/tvm/runtime/relax_vm/vm.h
index d39de74f2d..bd59106cc1 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -23,6 +23,10 @@
 #ifndef TVM_RUNTIME_RELAX_VM_VM_H_
 #define TVM_RUNTIME_RELAX_VM_VM_H_
 
+#ifndef TVM_RELAX_VM_ENABLE_PROFILER
+#define TVM_RELAX_VM_ENABLE_PROFILER 1  // default on; predefine as 0 to compile out the VM profiler
+#endif
+
 #include <memory>
 #include <string>
 #include <vector>
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
new file mode 100644
index 0000000000..18cbf332c8
--- /dev/null
+++ b/python/tvm/contrib/tvmjs.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Namespace to store utilities for building web runtime."""
+# pylint: disable=unused-import
+import sys
+import os
+import json
+from typing import Mapping, Union
+
+import numpy as np
+
+import tvm
+from .emcc import create_tvmjs_wasm
+
+
+def _convert_f32_to_bf16(value):  # float32 array -> uint16 array of bfloat16 bit patterns
+    cap = np.finfo("float32").max
+    assert -np.finfo("float32").max == np.finfo("float32").min
+    bf16_limit = ((np.array([cap.view("uint32")]) >> 16) << 
16).view("float32")[0]
+    # When the value is in (-bf16_limit, bf16_limit), round to nearest even.
+    # We can afford to do it in the dumping phase to reduce overall rounding error.
+    #
+    # When the value is out of bounds (usually mask values in attention), use 
truncation
+    # so it is equivalent to clipping finite values to the limit values.
+    data = value.view("uint32")
+    rounding_bias = np.where(
+        np.logical_and(value < bf16_limit, value > -bf16_limit),
+        ((data >> 16) & 1) + 0x7FFF,
+        np.zeros_like(data),
+    )
+    return ((data + rounding_bias) >> 16).astype("uint16")
+
+
+def dump_ndarray_cache(
+    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+    cachedir: str,
+    encode_format="f32-to-bf16",
+):
+    """Dump parameters to NDArray cache.
+
+    Parameters
+    ----------
+    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]]
+        The parameter dictionary
+
+    cachedir: str
+        The path to the cache
+
+    encode_format: {"f32-to-bf16", "raw"}
+        Encoding format; only float32 arrays are bf16-encoded, others are stored raw.
+    """
+    records = []
+    total = len(params)
+    counter = 0
+    max_out_length = 0
+
+    if not os.path.exists(cachedir):
+        os.makedirs(cachedir)
+
+    f32_to_bf16_triggered = False
+
+    print("Start storing to cache %s" % cachedir)
+    for k, v in params.items():
+        fname = k + ".bin"
+        out_path = os.path.join(cachedir, fname)
+        shape = list(v.shape)
+
+        if not isinstance(v, np.ndarray):
+            v = v.numpy()
+
+        # convert fp32 to bf16; any other dtype is written as raw bytes
+        if encode_format == "f32-to-bf16" and v.dtype == "float32":
+            _convert_f32_to_bf16(v).tofile(out_path)
+            data_format = "f32-to-bf16"
+            f32_to_bf16_triggered = True
+        else:
+            v.tofile(out_path)
+            data_format = "raw"
+        dtype = str(v.dtype)  # original dtype; "format" says how the file is encoded
+        records.append(
+            {"name": k, "shape": shape, "dtype": dtype, "dataPath": fname, "format": data_format}
+        )
+        counter += 1
+        last_cmd = "[%04d/%04d] saving %s" % (counter, total, out_path)
+        flush = "\r" + (" " * max_out_length) + "\r"
+        max_out_length = max(len(last_cmd), max_out_length)
+        sys.stdout.write(flush + last_cmd)
+
+    nd_cache_json = os.path.join(cachedir, "ndarray-cache.json")
+    with open(nd_cache_json, "w") as outfile:
+        json.dump(records, outfile, indent=4)
+    print("\nAll finished, record saved to %s" % nd_cache_json)
+
+    if f32_to_bf16_triggered:
+        rec_bf16 = []
+        for item in records:
+            if item["dtype"] == "float32":
+                item["format"] = "raw"
+                item["dtype"] = "bfloat16"
+            rec_bf16.append(item)
+        b16_nd_cache_json = os.path.join(cachedir, "ndarray-cache-b16.json")
+        # also dump a record that describes the encoded float32 arrays as raw bfloat16
+        with open(b16_nd_cache_json, "w") as outfile:
+            json.dump(rec_bf16, outfile, indent=4)
+        print("Also saved a bf16 record to %s" % b16_nd_cache_json)
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index 7eae4fe174..d340750785 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -19,6 +19,7 @@
 import logging
 import argparse
 import os
+import glob
 from tvm.rpc.proxy import Proxy
 
 
@@ -28,16 +29,29 @@ def find_example_resource():
     base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
     index_page = os.path.join(base_path, "web", "apps", "browser", 
"rpc_server.html")
     resource_files = [
-        os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"),
-        os.path.join(base_path, "web", "dist", "wasm", 
"tvmjs_runtime.wasi.js"),
+        ("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")),
+        ("/", os.path.join(base_path, "web", "dist", "wasm", 
"tvmjs_runtime.wasi.js")),
+        ("/", index_page),
     ]
-    resource_base = os.path.join(base_path, "web", "dist", "www")
-    if os.path.isdir(resource_base):
-        for fname in os.listdir(resource_base):
-            full_name = os.path.join(resource_base, fname)
-            if os.path.isfile(full_name):
-                resource_files.append(full_name)
-    for fname in [index_page] + resource_files:
+    allow_format = ("json", "bin", "js", "wasm")
+
+    # recursively apend things in www, up to two levels
+    resource_bases = [
+        os.path.join(base_path, "web", "dist", "www"),
+        os.path.join(base_path, "web", ".ndarray_cache"),
+    ]
+    for base in resource_bases:
+        if not os.path.isdir(base):
+            continue
+        for full_name in glob.glob("%s/**" % base, recursive=True):
+            fname = os.path.relpath(full_name, base)
+            dirname = os.path.dirname(fname)
+            fmt = fname.rsplit(".", 1)[-1]
+            if os.path.isfile(full_name) and fmt in allow_format:
+                resource_files.append((dirname, full_name))
+
+    for item in resource_files:
+        fname = item[-1]
         if not os.path.exists(fname):
             raise RuntimeError("Cannot find %s" % fname)
     return index_page, resource_files
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 35fc65bdc6..0586bf9217 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -180,6 +180,18 @@ def _vmcodegen(
     raise ValueError("Unknown exec_mode %s" % exec_mode)
 
 
+def _autodetect_system_lib_req(target: tvm.target.Target):
+    """Return CreateRuntime("cpp", system-lib) if the host mtriple contains "wasm", else None."""
+    host = target if target.host is None else target.host  # NOTE(review): needs a Target, not str
+    system_lib = False
+    if "wasm" in host.attrs.get("mtriple", ""):
+        system_lib = True
+    if system_lib:
+        # look the constructor up by name (packed func) to avoid a direct relay dependency
+        return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", 
{"system-lib": system_lib})
+    return None
+
+
 def _vmlink(
     builder: "relax.ExecBuilder",
     target: Union[str, tvm.target.Target],
@@ -224,7 +236,7 @@ def _vmlink(
         ext_libs = []
     lib = None
     if tir_mod is not None:
-        lib = tvm.build(tir_mod, target=target)
+        lib = tvm.build(tir_mod, target=target, 
runtime=_autodetect_system_lib_req(target))
     return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) 
 # type: ignore
 
 
diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py
index d7027c88a4..59af53d4e1 100644
--- a/python/tvm/rpc/proxy.py
+++ b/python/tvm/rpc/proxy.py
@@ -203,11 +203,20 @@ class WebSocketHandler(websocket.WebSocketHandler, 
ForwardHandler):
         self.close()
 
 
+MIME_MAP = {  # file extension -> Content-Type header set by RequestHandler.get
+    "js": "application/javascript",
+    "wasm": "application/wasm",
+    "json": "application/json",
+}
+
+
 class RequestHandler(tornado.web.RequestHandler):
     """Handles html request."""
 
     def __init__(self, *args, **kwargs):
         file_path = kwargs.pop("file_path")
+        self.format = file_path.split(".")[-1]
+
         if file_path.endswith("html"):
             self.page = open(file_path).read()
             web_port = kwargs.pop("rpc_web_port", None)
@@ -217,12 +226,15 @@ class RequestHandler(tornado.web.RequestHandler):
                 )
         else:
             self.page = open(file_path, "rb").read()
+
         super(RequestHandler, self).__init__(*args, **kwargs)
 
     def data_received(self, _):
         pass
 
     def get(self, *args, **kwargs):
+        if self.format in MIME_MAP:  # other extensions (e.g. html) get no explicit header here
+            self.set_header("Content-Type", MIME_MAP[self.format])
         self.write(self.page)
 
 
@@ -254,9 +266,14 @@ class ProxyServerHandler(object):
                 )
                 logging.info("Serving RPC index html page at 
http://localhost:%d";, web_port)
             resource_files = resource_files if resource_files else []
-            for fname in resource_files:
+            for item in resource_files:
+                prefix, fname = item
+                if not prefix.endswith("/"):
+                    prefix += "/"
+                if not prefix.startswith("/"):
+                    prefix = "/" + prefix
                 basename = os.path.basename(fname)
-                pair = (r"/%s" % basename, RequestHandler, {"file_path": 
fname})
+                pair = (r"%s%s" % (prefix, basename), RequestHandler, 
{"file_path": fname})
                 handlers.append(pair)
                 logging.info(pair)
             self.app = tornado.web.Application(handlers)
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 3b952c1ff5..8679b2a793 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -827,6 +827,11 @@ void VirtualMachineImpl::RunLoop() {
 
 ObjectPtr<VirtualMachine> VirtualMachine::Create() { return 
make_object<VirtualMachineImpl>(); }
 
+//----------------------------------------------------------------
+// The profiler can optionally be disabled via a macro to reduce dependencies.
+//----------------------------------------------------------------
+#if TVM_RELAX_VM_ENABLE_PROFILER
+
 /*!
  * \brief An extension of VirtualMachineImpl to support per-op profiling
  * It overrides RunInstrCall to add instrumentations around it.
@@ -927,6 +932,12 @@ ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() 
{
   return make_object<VirtualMachineProfiler>();
 }
 
+#else  // TVM_RELAX_VM_ENABLE_PROFILER == 0: keep the symbol, fail when it is called
+ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() {
+  LOG(FATAL) << "Profiler support is disabled";
+  return nullptr;  // presumably unreachable after LOG(FATAL); keeps the function well-formed
+}
+#endif  // TVM_RELAX_VM_ENABLE_PROFILER
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/web/.gitignore b/web/.gitignore
index 1f7cc0916a..69bf96a8a7 100644
--- a/web/.gitignore
+++ b/web/.gitignore
@@ -4,3 +4,4 @@ out
 node_modules
 build
 debug
+.ndarray_cache
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index 6d353e29b0..8fa50272b2 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -15,38 +15,71 @@
   <!--- KIND, either express or implied.  See the License for the -->
   <!--- specific language governing permissions and limitations -->
   <!--- under the License. -->
+  <!DOCTYPE html>
 
-  <head>
+  <head lang="en-US">
     <title>TVM RPC Test Page</title>
   </head>
   <script src="tvmjs_runtime.wasi.js"></script>
   <script src="tvmjs.bundle.js"></script>
   <script>
-
     function customLog(message) {
       console.log(message);
       const d = document.createElement("div");
       d.innerHTML = message;
       document.getElementById("log").appendChild(d);
     };
+
     function clearLog() {
       const node = document.getElementById("log");
       while (node.hasChildNodes()) {
         node.removeChild(node.lastChild);
       }
     }
+
+    function fetchProgressCallback(report) {
+      document.getElementById("fetch-text").innerHTML = report.text;
+      // guard 0 / 0: progress.value is a WebIDL double and rejects NaN with a TypeError
+      const percent = report.totalBytes > 0 ? (report.fetchedBytes / report.totalBytes) * 100 : 100;
+      document.getElementById("fetch-progress").value = percent;
+    }
+
     function connectRPC() {
-      const proxyUrl = document.getElementById("proxyURL").value;
+      const proxyUrl = document.getElementById("proxyUrl").value;
       const key = document.getElementById("proxyKey").value;
+      const ndarrayCacheName = document.getElementById("cache-select").value;
+      let ndarrayCacheUrl = new URL(ndarrayCacheName + "/", document.URL).href;
+      let ndarrayCacheDevice = 
document.getElementById("ndarrayCacheDevice").value;
+
+      if (ndarrayCacheName == "none" || ndarrayCacheName === undefined) {  // empty URL: no cache fetch
+        ndarrayCacheUrl = "";
+      }
+
       // only works for once.
       const getImports = () => {
         return new EmccWASI();
       };
 
-      new tvmjs.RPCServer(proxyUrl, key, getImports, customLog);
+      new tvmjs.RPCServer(
+        proxyUrl, key, getImports, customLog,
+        ndarrayCacheUrl, ndarrayCacheDevice, fetchProgressCallback);
+    }
+
+    async function loadCacheOption() {
+      const select = document.getElementById("cache-select");
+      try {
+        const list = await (await fetch("/cache-list.json")).json();
+        for (let i = 0; i < list.length; ++i) {
+          const option = document.createElement("option");
+          option.text = list[i];
+          option.value = list[i];
+          select.add(option);
+        }
+        if (list.length != 0) {
+          select.value = list[0];
+        }
+      } catch (err) {
+        // best effort: the proxy may not serve a cache list, so report instead of failing
+        console.log("Cannot load /cache-list.json: " + err);
+      }
     }
   </script>
-  <body>
+  <body onload="loadCacheOption()">
     <h1>TVM WebSocket RPC Server</h1>
     To use this page
     <ul>
@@ -59,20 +92,34 @@
     </ul>
 
     <h2>Options</h2>
-    Proxy URL<input
-      name="proxyurl"
-      id="proxyURL"
+    Proxy URL <input
+      name="proxyurl"
+      id="proxyUrl"
       type="text"
       value="ws://localhost:8888/ws"
     /><br />
-    RPC Server Key<input
+    RPC Server Key <input
       name="serverkey"
       id="proxyKey"
       type="text"
       value="wasm"
     /><br />
+    NDArrayCache -
+    <select name="cache-name" id="cache-select">
+      <option value="none">none</option>
+    </select>
+    CacheDevice -
+    <select name="cache-device" id="ndarrayCacheDevice">
+      <option value="webgpu">webgpu</option>
+      <option value="cpu">cpu</option>
+    </select>
+    <br />
     <button onclick="connectRPC()">Connect To Proxy</button>
     <button onclick="clearLog()">Clear Log</button>
+    <div id="progress">
+      <label id="fetch-text"></label>
+      <progress id="fetch-progress" max="100" value="100"> </progress>
+    </div>
     <div id="log"></div>
     <canvas id="canvas" width="224" height="224"></canvas>
   </body>
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 00d2a8c579..c90b917c5c 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -26,6 +26,7 @@
 #define TVM_LOG_STACK_TRACE 0
 #define TVM_LOG_DEBUG 0
 #define TVM_LOG_CUSTOMIZE 1
+
 #define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
 
 #include <tvm/runtime/c_runtime_api.h>
@@ -51,6 +52,12 @@
 #include "src/runtime/rpc/rpc_session.cc"
 #include "src/runtime/system_library.cc"
 #include "src/runtime/workspace_pool.cc"
+// relax setup
+#include "src/runtime/relax_vm/builtin.cc"
+#include "src/runtime/relax_vm/bytecode.cc"
+#include "src/runtime/relax_vm/executable.cc"
+#include "src/runtime/relax_vm/memory_manager.cc"
+#include "src/runtime/relax_vm/vm.cc"
 
 // --- Implementations of backend and wasm runtime API. ---
 
@@ -111,5 +118,72 @@ 
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
   // and get another value.
   *ret = (obj.use_count() - 1);
 });
+
+/*!
+ * An NDArray cache that holds pre-loaded arrays by name in a process-wide singleton.
+ */
+class NDArrayCache {
+ public:
+  static NDArrayCache* Global() {
+    static NDArrayCache* inst = new NDArrayCache();  // created on first use, never deleted
+    return inst;
+  }
+
+  static void Update(String name, NDArray arr, bool override) {
+    NDArrayCache* pool = Global();
+    if (!override) {  // without override, inserting an existing name fails the check below
+      ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already 
exists in the cache";
+    }
+    pool->pool_.Set(name, arr);
+  }
+
+  static Optional<NDArray> Get(String name) {  // NullOpt when the name is not cached
+    NDArrayCache* pool = Global();
+    auto it = pool->pool_.find(name);
+    if (it != pool->pool_.end()) {
+      return (*it).second;
+    } else {
+      return NullOpt;
+    }
+  }
+
+  static void Remove(String name) {
+    NDArrayCache* pool = Global();
+    pool->pool_.erase(name);
+  }
+
+  static void Clear() { Global()->pool_.clear(); }
+
+ private:
+  Map<String, NDArray> pool_;
+};
+
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.update").set_body_typed(NDArrayCache::Update);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);
+
+void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format) {
+  if (format == "f32-to-bf16") {  // bytes hold the upper 16 bits (bf16) of each float32 element
+    ICHECK(cpu_arr.IsContiguous() && cpu_arr->device.device_type == kDLCPU);
+    ICHECK(DataType(cpu_arr->dtype) == DataType::Float(32)) << "f32-to-bf16 needs a float32 array";
+    size_t size = 1;
+    for (int i = 0; i < cpu_arr->ndim; ++i) {
+      size *= cpu_arr->shape[i];
+    }
+    ICHECK_EQ(size * 2, bytes.length()) << "bf16 payload size does not match the array shape";
+    // memcpy out of the string rather than reinterpreting its buffer as uint16_t
+    std::vector<uint16_t> bf16(size);
+    std::memcpy(bf16.data(), bytes.data(), size * 2);
+    uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
+    // decode bf16 to f32
+    for (size_t i = 0; i < size; ++i) {
+      data[i] = static_cast<uint32_t>(bf16[i]) << 16;
+    }
+  } else {
+    cpu_arr.CopyFromBytes(bytes.data(), bytes.length());
+  }
+}
+
+TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
 }  // namespace runtime
 }  // namespace tvm
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index e37d1838d6..4dd7228d3c 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -82,6 +82,9 @@ export class RPCServer {
   state: RPCServerState = RPCServerState.InitHeader;
   logger: (msg: string) => void;
   getImports: () => Record<string, unknown>;
+  private ndarrayCacheUrl: string;
+  private ndarrayCacheDevice: string;
+  private fetchProgressCallback?: runtime.FetchProgressCallback;
   private pendingSend: Promise<void> = Promise.resolve();
   private name: string;
   private inst?: runtime.Instance = undefined;
@@ -98,13 +101,19 @@ export class RPCServer {
     url: string,
     key: string,
     getImports: () => Record<string, unknown>,
-    logger: (msg: string) => void = console.log
+    logger: (msg: string) => void = console.log,
+    ndarrayCacheUrl: string = "",
+    ndarrayCacheDevice: string = "cpu",
+    fetchProgressCallback: runtime.FetchProgressCallback | undefined = 
undefined
   ) {
     this.url = url;
     this.key = key;
     this.name = "WebSocketRPCServer[" + this.key + "]: ";
     this.getImports = getImports;
     this.logger = logger;
+    this.ndarrayCacheUrl = ndarrayCacheUrl;
+    this.ndarrayCacheDevice = ndarrayCacheDevice;
+    this.fetchProgressCallback = fetchProgressCallback;
 
     this.checkLittleEndian();
     this.socket = compact.createWebSocket(url);
@@ -132,7 +141,9 @@ export class RPCServer {
     if (this.state == RPCServerState.ReceivePacketHeader) {
       this.log("Closing the server in clean state");
       this.log("Automatic reconnecting..");
-      new RPCServer(this.url, this.key, this.getImports, this.logger);
+      new RPCServer(
+        this.url, this.key, this.getImports, this.logger,
+        this.ndarrayCacheUrl, this.ndarrayCacheDevice, 
this.fetchProgressCallback);
     } else {
       this.log("Closing the server, final state=" + this.state);
     }
@@ -272,6 +283,20 @@ export class RPCServer {
       // begin scope to allow handling of objects
       // the object should stay alive during all sessions.
       this.inst.beginScope();
+      if (this.fetchProgressCallback !== undefined) {
+        this.inst.registerFetchProgressCallback(this.fetchProgressCallback);
+      }
+
+      if (this.ndarrayCacheUrl.length != 0) {
+        if (this.ndarrayCacheDevice == "cpu") {
+          await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, 
this.inst.cpu());
+        } else {
+          assert(this.ndarrayCacheDevice == "webgpu");
+          await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, 
this.inst.webgpu());
+        }
+      }
+
+      assert(this.inst !== undefined);
       const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");
 
       const messageHandler = fcreate(
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index a24459ca29..463532762e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -29,7 +29,6 @@ import { WebGPUContext } from "./webgpu";
 
 import * as compact from "./compact";
 import * as ctypes from "./ctypes";
-import { tsImportEqualsDeclaration } from "@babel/types";
 
 /**
  * Type for PackedFunc inthe TVMRuntime.
@@ -144,6 +143,11 @@ class RuntimeContext implements Disposable {
   arrayGetSize : PackedFunc;
   arrayMake : PackedFunc;
   getSysLib: PackedFunc;
+  arrayCacheGet: PackedFunc;
+  arrayCacheUpdate: PackedFunc;
+  arrayCacheRemove: PackedFunc;
+  arrayCacheClear: PackedFunc;
+  arrayDecodeStorage: PackedFunc;
 
   private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
 
@@ -152,12 +156,25 @@ class RuntimeContext implements Disposable {
     this.arrayGetSize = getGlobalFunc("runtime.ArraySize");
     this.arrayMake = getGlobalFunc("runtime.Array");
     this.getSysLib = getGlobalFunc("runtime.SystemLib");
+    this.arrayCacheGet = getGlobalFunc("tvmjs.ndarray_cache.get");
+    this.arrayCacheRemove = getGlobalFunc("tvmjs.ndarray_cache.remove");
+    this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update");
+    this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
+    this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
+
   }
 
   dispose(): void {
+    // clear the NDArray cache before disposing the cache functions themselves
+    this.arrayCacheClear();
     this.arrayGetItem.dispose();
     this.arrayGetSize.dispose();
     this.arrayMake.dispose();
+    this.arrayCacheGet.dispose();
+    this.arrayCacheRemove.dispose();
+    this.arrayCacheUpdate.dispose();
+    this.arrayCacheClear.dispose();
+    this.arrayDecodeStorage.dispose();
   }
 
   beginScope() : void {
@@ -522,6 +539,9 @@ export class NDArray implements Disposable {
    * @returns this
    */
   copyFromRawBytes(data: Uint8Array): this {
+    if (this.device.deviceType != DeviceStrToEnum.cpu) {
+      throw new Error("Can only sync copy CPU array, use 
cpu_arr.copyfrom(gpu_arr) then sync instead.");
+    }
     const size = this.shape.reduce((a, b) => {
       return a * b;
     }, 1);
@@ -552,7 +572,7 @@ export class NDArray implements Disposable {
    */
   toRawBytes(): Uint8Array {
     if (this.device.deviceType != DeviceStrToEnum.cpu) {
-      throw new Error("Can only synchronize copy for GPU array, use copyfrom 
instead.");
+      throw new Error("Can only sync copy CPU array, use 
cpu_arr.copyfrom(gpu_arr) then sync instead.");
     }
     const size = this.shape.reduce((a, b) => {
       return a * b;
@@ -806,12 +826,70 @@ export class TVMArray extends TVMObject {
   }
 }
 
+export const enum VMAllocatorKind {  // allocator kind passed to vm_initialization
+  NAIVE_ALLOCATOR = 1,
+  POOLED_ALLOCATOR = 2,
+}
+
+/**
+ *  VirtualMachine Executor.
+ *
+ *  This is a thin wrapper of the underlying TVM module.
+ *  You can also directly call set_input, run, and get_output
+ *  functions of the underlying module via getFunction.
+ */
+export class VirtualMachine implements Disposable {
+  private mod: Module;
+  /**
+   * Constructor; initializes the VM on device with the pooled allocator.
+   * @param mod The underlying module; must be detached, this VM owns and disposes it.
+   * @param device The main device to run the VM on.
+   */
+  constructor(mod: Module, device: DLDevice) {
+    this.mod = mod;
+    this.mod.getFunction("vm_initialization")(
+      new Scalar(device.deviceType, "int"),
+      new Scalar(device.deviceId, "int"),
+      new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int")
+    );
+  }
+
+  dispose(): void {
+    this.mod.dispose();
+  }
+  /**
+   * Get a function in the VM module.
+   * @param name The name of the function.
+   * @returns The result function.
+   */
+  getFunction(name: string): PackedFunc {
+    return this.mod.getFunction(name);
+  }
+}
+
 /** Code used as the first argument of the async callback. */
 const enum AyncCallbackCode {
   kReturn = 4,
   kException = 5,
 }
 
+export interface NDArrayCacheEntry {  // one record of ndarray-cache.json
+  name: string;
+  shape: Array<number>;
+  dtype: string;
+  format: "f32-to-bf16" | "raw";  // how the data file is encoded
+  dataPath: string;  // data file name within the cache directory
+}
+
+export interface FetchProgressReport {
+  fetchedBytes: number;
+  totalBytes: number;
+  timeElapsed: number;  // units not visible here (presumably seconds) — TODO confirm
+  text: string;  // human-readable status text
+}
+
+export type FetchProgressCallback = (report: FetchProgressReport) => void;
+
 /**
  * TVM runtime instance.
  *
@@ -836,6 +914,7 @@ export class Instance implements Disposable {
   private env: Environment;
   private objFactory: Map<number, FObjectConstructor>;
   private ctx: RuntimeContext;
+  private fetchProgressCallback: Array<FetchProgressCallback> = [];
 
   /**
    * Internal function(registered by the runtime)
@@ -898,26 +977,26 @@ export class Instance implements Disposable {
    * @number The number of times to compute the average.
    * @repeat The number of times to repeat the run.
    */
-    async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): 
Promise<number[]> {
-      // Skip first run as it can involve GPU warmup and module loading time.
-      const perf = compact.getPerformance();
-      const results = [];
+  async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): 
Promise<number[]> {
+    // Run once untimed first: it can involve GPU warmup and module loading time.
+    const perf = compact.getPerformance();
+    const results = [];
 
-      // run with new scope
-      this.withNewScope(run);
-      await dev.sync();
+    // run with new scope
+    this.withNewScope(run);
+    await dev.sync();
 
-      for (let k = 0; k < repeat; ++k) {
-        const tstart = perf.now();
-        for (let i = 0; i < number; ++i) {
-          this.withNewScope(run);
-        }
-        await dev.sync();
-        const tend = perf.now();
-        results.push((tend - tstart) / number);
+    for (let k = 0; k < repeat; ++k) {
+      const tstart = perf.now();
+      for (let i = 0; i < number; ++i) {
+        this.withNewScope(run);
       }
-      return results;
+      await dev.sync();
+      const tend = perf.now();
+      results.push((tend - tstart) / number);
     }
+    return results;
+  }
 
   dispose(): void {
     // order matters
@@ -1131,9 +1210,9 @@ export class Instance implements Disposable {
    * @param func Input function.
    * @returns The converted function.
    */
-   toPackedFunc(func: Function): PackedFunc {
-      return this.toPackedFuncInternal(func, true);
-   }
+  toPackedFunc(func: Function): PackedFunc {
+    return this.toPackedFuncInternal(func, true);  // true: auto-attach the result to the current scope
+  }
 
   private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): 
PackedFunc {
     if (this.isPackedFunc(func)) return func as PackedFunc;
@@ -1142,6 +1221,200 @@ export class Instance implements Disposable {
     return ret;
   }
 
+  /**
+   * Set up a virtual machine module on the given device.
+   *
+   * @param dev DLDevice the device.
+   * @returns The created virtual machine.
+   */
+  createVirtualMachine(dev: DLDevice): VirtualMachine {
+    const mod = this.ctx.detachFromCurrentScope(
+      this.systemLib().getFunction("vm_load_executable")()
+    );
+    return this.ctx.attachToCurrentScope(
+      new VirtualMachine(mod, dev)
+    );
+  }
+
+  //-----------------------------------------------
+  // Native NDArray Cache Support
+  //-----------------------------------------------
+  /**
+   * Register a callback for fetch progress.
+   *
+   * @param cb the fetch progress callback.
+   */
+  registerFetchProgressCallback(cb: FetchProgressCallback) {
+    this.fetchProgressCallback.push(cb);
+  }
+
+  /**
+   * Get NDArray from cache.
+   * @param name  The name of array.
+   * @returns  The result.
+   */
+  ndarrayCacheGet(name: string) : NDArray | undefined {
+    return this.ctx.arrayCacheGet(name);
+  }
+
+  /**
+   * Remove NDArray from cache.
+   * @param name  The name of array.
+   * @returns  The result.
+   */
+  ndarrayCacheRemove(name: string) : NDArray | undefined {
+    return this.ctx.arrayCacheRemove(name);
+  }
+
+  /**
+   * Update the ndarray cache.
+   * @param name The name of the array.
+   * @param arr The content.
+   */
+  ndarrayCacheUpdate(name: string, arr: NDArray, override: boolean = false) {
+    this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, 
"int32"));
+  }
+
+  /**
+   * Clear the ndarray cache.
+   *
+   * Takes no arguments; delegates to ctx.arrayCacheClear().
+   */
+  ndarrayCacheClear() {
+    this.ctx.arrayCacheClear();
+  }
+
+  /**
+   * Fetch NDArray cache from url.
+   *
+   * @param ndarrayCacheUrl The cache url.
+   * @param device The device to load the fetched arrays onto.
+   */
+  async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) {
+    const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
+    var list;
+    try {
+
+      list = await (await fetch(jsonUrl)).json();
+    } catch(err) {
+      this.env.logger("Cannot fetch " + jsonUrl);
+    }
+    await this.fetchNDArrayCacheInternal(ndarrayCacheUrl, list as 
Array<NDArrayCacheEntry>, device);
+  }
+
+  /**
+   * Fetch list of NDArray into the NDArrayCache.
+   *
+   * @param ndarrayCacheUrl The cache url.
+   * @param list The list of array data.
+   * @param device The device to store the data to.
+   */
+  private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list: 
Array<NDArrayCacheEntry>, device: DLDevice) {
+    const computeTotalBytes = (rec: NDArrayCacheEntry) => {
+
+      const dtype = this.toDLDataType(rec.dtype);
+      const size = rec.shape.reduce((a, b) => {
+        return a * b;
+      }, 1);
+      if (rec.format == "f32-to-bf16" && rec.dtype == "float32") {
+        return size * 2;
+      }
+      return size * dtype.bits * dtype.lanes / 8;
+    };
+    const perf = compact.getPerformance();
+    let tstart = perf.now();
+
+    let totalBytes = 0;
+    for (let i = 0; i < list.length; ++i) {
+      totalBytes += computeTotalBytes(list[i]);
+    };
+    let fetchedBytes = 0;
+    let timeElapsed = 0;
+
+    const reportCallback = (iter: number)=> {
+      // report
+      for (let j = 0; j < this.fetchProgressCallback.length; ++j) {
+        let text = "Fetching NDArray Cache[" + iter + "/" + list.length+ "]:";
+        text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB 
fetched "
+        text += "from " + Math.ceil(totalBytes / (1024 * 1024)).toString() + 
"MB, "
+        text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% 
completed, "
+        text += timeElapsed + " secs elapsed";
+        if (timeElapsed != 0){
+          text += ", speed=" + (fetchedBytes / timeElapsed / (1024 * 
1024)).toFixed(1) + " MB/sec";
+        }
+        this.fetchProgressCallback[j]({
+          fetchedBytes: fetchedBytes,
+          totalBytes: totalBytes,
+          timeElapsed: timeElapsed,
+          text: text
+        });
+      }
+    };
+
+    for (let j = 0; j < this.fetchProgressCallback.length; ++j) {
+      this.fetchProgressCallback[j]({
+        fetchedBytes: 0,
+        totalBytes: totalBytes,
+        timeElapsed: 0,
+        text: "Start to fetch " + ndarrayCacheUrl
+      });
+    }
+    const cache = await caches.open("tvmjs");
+
+    for (let i = 0; i < list.length; ++i) {
+      const rec = list[i];
+      reportCallback(i);
+      fetchedBytes += computeTotalBytes(rec);
+      const cpu_arr = this.withNewScope(() => {
+        return this.detachFromCurrentScope(
+          this.empty(rec.shape, rec.dtype, this.cpu())
+        )
+      });
+      const dataUrl = new URL(rec.dataPath, ndarrayCacheUrl).href;
+      const request = new Request(dataUrl);
+
+      let buffer;
+      try {
+        // use native cache
+        let result = await cache.match(request);
+        if (result === undefined) {
+          await cache.add(request);
+          result = await cache.match(request);
+        }
+        if (result == undefined) {
+          this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will 
be slow");
+          result = await fetch(request);
+        }
+        buffer = await result.arrayBuffer();
+      } catch (err) {
+        this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
+        cpu_arr.dispose();
+        throw err;
+      }
+      // first sync copy to cpu.
+      this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(buffer), rec.format);
+      // then async stream into GPU if needed
+      if (device.deviceType == DeviceStrToEnum.cpu) {
+        this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
+        cpu_arr.dispose();
+      } else {
+        // allocate a gpu arr and async copy to it.
+        const gpu_arr = this.withNewScope(() => {
+          return this.detachFromCurrentScope(
+            this.empty(rec.shape, rec.dtype, device)
+          )
+        });
+        gpu_arr.copyFrom(cpu_arr);
+        await device.sync();
+        this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
+        cpu_arr.dispose();
+        gpu_arr.dispose();
+      }
+      timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
+    }
+    reportCallback(list.length);
+  }
+
   /**
    * Convert dtype to {@link DLDataType}
    *
diff --git a/web/tests/node/test_relax_vm.js b/web/tests/node/test_relax_vm.js
new file mode 100644
index 0000000000..ceb47aa014
--- /dev/null
+++ b/web/tests/node/test_relax_vm.js
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/* eslint-disable no-undef */
+// Load Emscripten Module, need to change path to root/lib
+const path = require("path");
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "test_relax.wasm"));
+
+const tvm = new tvmjs.Instance(
+  new WebAssembly.Module(wasmSource),
+  new EmccWASI()
+);
+
+
+function randomArray(length, max) {
+  return Array.apply(null, Array(length)).map(function () {
+    return Math.random() * max;
+  });
+}
+
+test("add one", () => {
+  tvm.beginScope();
+  // Create a virtual machine from the system library's executable
+  const vm = tvm.createVirtualMachine(tvm.cpu());
+  // grab pre-loaded function
+  const fadd = vm.getFunction("main");
+
+  assert(tvm.isPackedFunc(fadd));
+  const n = 124;
+  const A = tvm.empty(n).copyFrom(randomArray(n, 1));
+  const B = tvm.empty(n).copyFrom(randomArray(n, 1));
+
+  // call the function.
+  const C = fadd(A, B);
+  const AA = A.toArray(); // retrieve values in js array
+  const BB = B.toArray(); // retrieve values in js array
+  const CC = C.toArray(); // retrieve values in js array
+  // verify
+  for (var i = 0; i < BB.length; ++i) {
+    assert(Math.abs(CC[i] - (AA[i] + BB[i])) < 1e-5);
+  }
+  tvm.endScope();
+  // assert auto release scope behavior
+  assert(vm.mod.getHandle(false) == 0);
+  assert(fadd._tvmPackedCell.getHandle(false) == 0);
+});
diff --git a/web/tests/python/prepare_test_libs.py 
b/web/tests/python/prepare_test_libs.py
index 5c1f7c68c4..a63e0655b4 100644
--- a/web/tests/python/prepare_test_libs.py
+++ b/web/tests/python/prepare_test_libs.py
@@ -18,12 +18,32 @@
 
 import tvm
 from tvm import te
-from tvm.contrib import emcc
+from tvm.contrib import tvmjs
 from tvm.relay.backend import Runtime
+from tvm import relax
+from tvm.script import relax as R
 import os
 
 
-def prepare_test_libs(base_path):
+def prepare_relax_lib(base_path):
+    pipeline = relax.get_pipeline()
+
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")):
+            lv0 = R.add(x, y)
+            return lv0
+
+    target = tvm.target.Target("llvm -mtriple=wasm32-unknown-unknown-wasm")
+
+    mod = pipeline(Mod)
+    ex = relax.build(mod, target)
+    wasm_path = os.path.join(base_path, "test_relax.wasm")
+    ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
+
+
+def prepare_tir_lib(base_path):
     runtime = Runtime("cpp", {"system-lib": True})
     target = "llvm -mtriple=wasm32-unknown-unknown-wasm"
     if not tvm.runtime.enabled(target):
@@ -35,9 +55,11 @@ def prepare_test_libs(base_path):
     fadd = tvm.build(s, [A, B], target, runtime=runtime, name="add_one")
 
     wasm_path = os.path.join(base_path, "test_addone.wasm")
-    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+    fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
 
 
 if __name__ == "__main__":
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    prepare_test_libs(os.path.join(curr_path, "../../dist/wasm"))
+    base_path = os.path.join(curr_path, "../../dist/wasm")
+    prepare_tir_lib(base_path)
+    prepare_relax_lib(base_path)
diff --git a/web/tests/python/webgpu_rpc_test.py 
b/web/tests/python/relax_rpc_test.py
similarity index 51%
copy from web/tests/python/webgpu_rpc_test.py
copy to web/tests/python/relax_rpc_test.py
index 6e34a8a2b3..a347fe70b3 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/relax_rpc_test.py
@@ -14,47 +14,53 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Simple testcode to test Javascript RPC
-
-To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
-Connect javascript end to the websocket port and connect to the RPC.
-"""
+"""Test relax vm through rpc."""
 
 import tvm
-from tvm import te
-from tvm import rpc
-from tvm.contrib import utils, emcc
-from tvm.relay.backend import Runtime
 import numpy as np
+from tvm import rpc, relax
+from tvm.contrib import utils, tvmjs
+from tvm.script import relax as R
 
 proxy_host = "127.0.0.1"
 proxy_port = 9090
 
 
-def test_rpc():
-    if not tvm.runtime.enabled("rpc"):
-        return
-    # generate the wasm library
-    target = tvm.target.Target("webgpu", host="llvm 
-mtriple=wasm32-unknown-unknown-wasm")
-    runtime = Runtime("cpp", {"system-lib": True})
+def get_model():
+    pipeline = relax.get_pipeline()
 
-    n = 2048
-    A = te.placeholder((n,), name="A")
-    B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B")
-    s = te.create_schedule(B.op)
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], 
"float32")):
+            lv0 = R.add(x, y)
+            return lv0
 
-    num_thread = 2
-    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
-    s[B].bind(xi, te.thread_axis("threadIdx.x"))
-    s[B].bind(xo, te.thread_axis("blockIdx.x"))
+    mod = pipeline(Mod)
+    sch = tvm.tir.Schedule(mod)
+    # manually split the loop and bind it to GPU block/thread axes
+    sch.work_on("add")
+    (i,) = sch.get_loops(block=sch.get_block("T_add"))
+    i0, i1 = sch.split(i, [None, 128])
+    sch.bind(i0, "blockIdx.x")
+    sch.bind(i1, "threadIdx.x")
+    return sch.mod
 
-    fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone")
-    temp = utils.tempdir()
 
-    wasm_path = temp.relpath("addone_gpu.wasm")
-    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+def test_rpc():
+    if not tvm.runtime.enabled("rpc"):
+        return
+    n = 1024
+    dtype = "float32"
+    temp = utils.tempdir()
+    wasm_path = temp.relpath("relax.wasm")
+    target = tvm.target.Target("webgpu", host="llvm 
-mtriple=wasm32-unknown-unknown-wasm")
 
+    mod = get_model()
+    ex = relax.build(mod, target)
+    ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
     wasm_binary = open(wasm_path, "rb").read()
+
     remote = rpc.connect(
         proxy_host,
         proxy_port,
@@ -63,18 +69,17 @@ def test_rpc():
     )
 
     def check(remote):
-        # basic function checks.
         dev = remote.webgpu(0)
-        adata = np.random.uniform(size=n).astype(A.dtype)
+        # invoke the function
+        vm = relax.VirtualMachine(remote.system_lib(), device=dev)
+        adata = np.random.uniform(size=n).astype(dtype)
+        bdata = np.random.uniform(size=n).astype(dtype)
         a = tvm.nd.array(adata, dev)
-        b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
-
-        np.testing.assert_equal(a.numpy(), adata)
-        f1 = remote.system_lib()
-        addone = f1.get_function("addone")
-        addone(a, b)
-        np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), 
atol=1e-5, rtol=1e-5)
-        print("Test pass..")
+        b = tvm.nd.array(bdata, dev)
+        vm.set_input("main", a, b)
+        vm.invoke_stateful("main")
+        c = vm.get_outputs("main")
+        np.testing.assert_equal(c.numpy(), a.numpy() + b.numpy())
 
     check(remote)
 
diff --git a/web/tests/python/webgpu_rpc_test.py 
b/web/tests/python/webgpu_rpc_test.py
index 6e34a8a2b3..986393e9d4 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -23,7 +23,7 @@ Connect javascript end to the websocket port and connect to 
the RPC.
 import tvm
 from tvm import te
 from tvm import rpc
-from tvm.contrib import utils, emcc
+from tvm.contrib import utils, tvmjs
 from tvm.relay.backend import Runtime
 import numpy as np
 
@@ -52,7 +52,7 @@ def test_rpc():
     temp = utils.tempdir()
 
     wasm_path = temp.relpath("addone_gpu.wasm")
-    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+    fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
 
     wasm_binary = open(wasm_path, "rb").read()
     remote = rpc.connect(
diff --git a/web/tests/python/websock_rpc_test.py 
b/web/tests/python/websock_rpc_test.py
index 7de5ee956e..19d5dc5748 100644
--- a/web/tests/python/websock_rpc_test.py
+++ b/web/tests/python/websock_rpc_test.py
@@ -23,7 +23,7 @@ Connect javascript end to the websocket port and connect to 
the RPC.
 import tvm
 from tvm import te
 from tvm import rpc
-from tvm.contrib import utils, emcc
+from tvm.contrib import utils, tvmjs
 from tvm.relay.backend import Runtime
 import numpy as np
 
@@ -48,7 +48,7 @@ def test_rpc():
     temp = utils.tempdir()
 
     wasm_path = temp.relpath("addone.wasm")
-    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+    fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
 
     wasm_binary = open(wasm_path, "rb").read()
 


Reply via email to