This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 678d01dd4a [Unity][WEB] Relax vm on web runtime (#14131)
678d01dd4a is described below
commit 678d01dd4a4e75ef6186ce356bb1a20e584a7b24
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 25 13:22:10 2023 -0500
[Unity][WEB] Relax vm on web runtime (#14131)
This PR brings initial relax vm support on web runtime
---
include/tvm/runtime/relax_vm/vm.h | 4 +
python/tvm/contrib/tvmjs.py | 119 ++++++++
python/tvm/exec/rpc_proxy.py | 32 ++-
python/tvm/relax/vm_build.py | 14 +-
python/tvm/rpc/proxy.py | 21 +-
src/runtime/relax_vm/vm.cc | 11 +
web/.gitignore | 1 +
web/apps/browser/rpc_server.html | 65 ++++-
web/emcc/wasm_runtime.cc | 74 +++++
web/src/rpc_server.ts | 29 +-
web/src/runtime.ts | 315 +++++++++++++++++++--
web/tests/node/test_relax_vm.js | 67 +++++
web/tests/python/prepare_test_libs.py | 30 +-
.../{webgpu_rpc_test.py => relax_rpc_test.py} | 79 +++---
web/tests/python/webgpu_rpc_test.py | 4 +-
web/tests/python/websock_rpc_test.py | 4 +-
16 files changed, 780 insertions(+), 89 deletions(-)
diff --git a/include/tvm/runtime/relax_vm/vm.h
b/include/tvm/runtime/relax_vm/vm.h
index d39de74f2d..bd59106cc1 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -23,6 +23,10 @@
#ifndef TVM_RUNTIME_RELAX_VM_VM_H_
#define TVM_RUNTIME_RELAX_VM_VM_H_
+#ifndef TVM_RELAX_VM_ENABLE_PROFILER
+#define TVM_RELAX_VM_ENABLE_PROFILER 1
+#endif
+
#include <memory>
#include <string>
#include <vector>
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
new file mode 100644
index 0000000000..18cbf332c8
--- /dev/null
+++ b/python/tvm/contrib/tvmjs.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Namespace to store utilities for building web runtime."""
+# pylint: disable=unused-import
+import sys
+import os
+import json
+from typing import Mapping, Union
+
+import numpy as np
+
+import tvm
+from .emcc import create_tvmjs_wasm
+
+
+def _convert_f32_to_bf16(value):
+ cap = np.finfo("float32").max
+ assert -np.finfo("float32").max == np.finfo("float32").min
+ bf16_limit = ((np.array([cap.view("uint32")]) >> 16) <<
16).view("float32")[0]
+ # When the value is in [-bf16_limit, bf16_limit], round to nearest even.
+ # We can afford to do it in dumping phase to reduce overall rounding error.
+ #
+ # When the value is out of bound (usually mask values in attention), use
truncation
+ # so it is equivalent to clip to the limit values
+ data = value.view("uint32")
+ rounding_bias = np.where(
+ np.logical_and(value < bf16_limit, value > -bf16_limit),
+ ((data >> 16) & 1) + 0x7FFF,
+ np.zeros_like(data),
+ )
+ return ((data + rounding_bias) >> 16).astype("uint16")
+
+
+def dump_ndarray_cache(
+ params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+ cachedir: str,
+ encode_format="f32-to-bf16",
+):
+ """Dump parameters to NDArray cache.
+
+ Parameters
+ ----------
+ params: Mapping[str, tvm.runtime.NDArray],
+ The parameter dictionary
+
+ cachedir: str
+ The path to the cache
+
+ encode_format: {"f32-to-bf16", "raw"}
+ Encoding format.
+ """
+ records = []
+ total = len(params)
+ counter = 0
+ max_out_length = 0
+
+ if not os.path.exists(cachedir):
+ os.makedirs(cachedir)
+
+ f32_to_bf16_triggered = False
+
+ print("Start storing to cache %s" % cachedir)
+ for k, v in params.items():
+ fname = k + ".bin"
+ out_path = os.path.join(cachedir, fname)
+ shape = list(v.shape)
+
+ if not isinstance(v, np.ndarray):
+ v = v.numpy()
+
+ # convert fp32 to bf16
+ if encode_format == "f32-to-bf16" and v.dtype == "float32":
+ _convert_f32_to_bf16(v).tofile(out_path)
+ dtype = "bfloat16"
+ f32_to_bf16_triggered = True
+ else:
+ v.tofile(out_path)
+
+ dtype = str(v.dtype)
+ records.append(
+ {"name": k, "shape": shape, "dtype": dtype, "dataPath": fname,
"format": encode_format}
+ )
+ counter += 1
+ last_cmd = "[%04d/%04d] saving %s" % (counter, total, out_path)
+ flush = "\r" + (" " * max_out_length) + "\r"
+ max_out_length = max(len(last_cmd), max_out_length)
+ sys.stdout.write(flush + last_cmd)
+
+ nd_cache_json = os.path.join(cachedir, "ndarray-cache.json")
+ with open(nd_cache_json, "w") as outfile:
+ json.dump(records, outfile, indent=4)
+ print("\nAll finished, record saved to %s" % nd_cache_json)
+
+ if f32_to_bf16_triggered:
+ rec_bf16 = []
+ for item in records:
+ if item["dtype"] == "float32":
+ item["format"] = "raw"
+ item["dtype"] = "bfloat16"
+ rec_bf16.append(item)
+ b16_nd_cache_json = os.path.join(cachedir, "ndarray-cache-b16.json")
+ # also dump a file that contains bf16
+ with open(b16_nd_cache_json, "w") as outfile:
+ json.dump(rec_bf16, outfile, indent=4)
+ print("Also saved a bf16 record to %s" % b16_nd_cache_json)
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index 7eae4fe174..d340750785 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -19,6 +19,7 @@
import logging
import argparse
import os
+import glob
from tvm.rpc.proxy import Proxy
@@ -28,16 +29,29 @@ def find_example_resource():
base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
index_page = os.path.join(base_path, "web", "apps", "browser",
"rpc_server.html")
resource_files = [
- os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"),
- os.path.join(base_path, "web", "dist", "wasm",
"tvmjs_runtime.wasi.js"),
+ ("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")),
+ ("/", os.path.join(base_path, "web", "dist", "wasm",
"tvmjs_runtime.wasi.js")),
+ ("/", index_page),
]
- resource_base = os.path.join(base_path, "web", "dist", "www")
- if os.path.isdir(resource_base):
- for fname in os.listdir(resource_base):
- full_name = os.path.join(resource_base, fname)
- if os.path.isfile(full_name):
- resource_files.append(full_name)
- for fname in [index_page] + resource_files:
+ allow_format = ("json", "bin", "js", "wasm")
+
+ # recursively append things in www, up to two levels
+ resource_bases = [
+ os.path.join(base_path, "web", "dist", "www"),
+ os.path.join(base_path, "web", ".ndarray_cache"),
+ ]
+ for base in resource_bases:
+ if not os.path.isdir(base):
+ continue
+ for full_name in glob.glob("%s/**" % base, recursive=True):
+ fname = os.path.relpath(full_name, base)
+ dirname = os.path.dirname(fname)
+ fmt = fname.rsplit(".", 1)[-1]
+ if os.path.isfile(full_name) and fmt in allow_format:
+ resource_files.append((dirname, full_name))
+
+ for item in resource_files:
+ fname = item[-1]
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, resource_files
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 35fc65bdc6..0586bf9217 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -180,6 +180,18 @@ def _vmcodegen(
raise ValueError("Unknown exec_mode %s" % exec_mode)
+def _autodetect_system_lib_req(target: tvm.target.Target):
+ """Automatically detect system lib requirement"""
+ host = target if target.host is None else target.host
+ system_lib = False
+ if "wasm" in host.attrs.get("mtriple", ""):
+ system_lib = True
+ if system_lib:
+ # use packed-func to avoid relay dep.
+ return tvm.get_global_func("relay.backend.CreateRuntime")("cpp",
{"system-lib": system_lib})
+ return None
+
+
def _vmlink(
builder: "relax.ExecBuilder",
target: Union[str, tvm.target.Target],
@@ -224,7 +236,7 @@ def _vmlink(
ext_libs = []
lib = None
if tir_mod is not None:
- lib = tvm.build(tir_mod, target=target)
+ lib = tvm.build(tir_mod, target=target,
runtime=_autodetect_system_lib_req(target))
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))
# type: ignore
diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py
index d7027c88a4..59af53d4e1 100644
--- a/python/tvm/rpc/proxy.py
+++ b/python/tvm/rpc/proxy.py
@@ -203,11 +203,20 @@ class WebSocketHandler(websocket.WebSocketHandler,
ForwardHandler):
self.close()
+MIME_MAP = {
+ "js": "application/javascript",
+ "wasm": "application/wasm",
+ "json": "application/json",
+}
+
+
class RequestHandler(tornado.web.RequestHandler):
"""Handles html request."""
def __init__(self, *args, **kwargs):
file_path = kwargs.pop("file_path")
+ self.format = file_path.split(".")[-1]
+
if file_path.endswith("html"):
self.page = open(file_path).read()
web_port = kwargs.pop("rpc_web_port", None)
@@ -217,12 +226,15 @@ class RequestHandler(tornado.web.RequestHandler):
)
else:
self.page = open(file_path, "rb").read()
+
super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _):
pass
def get(self, *args, **kwargs):
+ if self.format in MIME_MAP:
+ self.set_header("Content-Type", MIME_MAP[self.format])
self.write(self.page)
@@ -254,9 +266,14 @@ class ProxyServerHandler(object):
)
logging.info("Serving RPC index html page at
http://localhost:%d", web_port)
resource_files = resource_files if resource_files else []
- for fname in resource_files:
+ for item in resource_files:
+ prefix, fname = item
+ if not prefix.endswith("/"):
+ prefix += "/"
+ if not prefix.startswith("/"):
+ prefix = "/" + prefix
basename = os.path.basename(fname)
- pair = (r"/%s" % basename, RequestHandler, {"file_path":
fname})
+ pair = (r"%s%s" % (prefix, basename), RequestHandler,
{"file_path": fname})
handlers.append(pair)
logging.info(pair)
self.app = tornado.web.Application(handlers)
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 3b952c1ff5..8679b2a793 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -827,6 +827,11 @@ void VirtualMachineImpl::RunLoop() {
ObjectPtr<VirtualMachine> VirtualMachine::Create() { return
make_object<VirtualMachineImpl>(); }
+//----------------------------------------------------------------
+// Profiler can be optionally disabled via a macro to reduce dep.
+//----------------------------------------------------------------
+#if TVM_RELAX_VM_ENABLE_PROFILER
+
/*!
* \brief An extension of VirtualMachineImpl to support per-op profiling
* It overrides RunInstrCall to add instrumentations around it.
@@ -927,6 +932,12 @@ ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler()
{
return make_object<VirtualMachineProfiler>();
}
+#else
+ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() {
+ LOG(FATAL) << "Profiler support is disabled";
+ return nullptr;
+}
+#endif // TVM_RELAX_VM_ENABLE_PROFILER
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
diff --git a/web/.gitignore b/web/.gitignore
index 1f7cc0916a..69bf96a8a7 100644
--- a/web/.gitignore
+++ b/web/.gitignore
@@ -4,3 +4,4 @@ out
node_modules
build
debug
+.ndarray_cache
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
index 6d353e29b0..8fa50272b2 100644
--- a/web/apps/browser/rpc_server.html
+++ b/web/apps/browser/rpc_server.html
@@ -15,38 +15,71 @@
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
+ <!DOCTYPE html>
- <head>
+ <head lang="en-US"></head>
<title>TVM RPC Test Page</title>
</head>
<script src="tvmjs_runtime.wasi.js"></script>
<script src="tvmjs.bundle.js"></script>
<script>
-
function customLog(message) {
console.log(message);
const d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
+
function clearLog() {
const node = document.getElementById("log");
while (node.hasChildNodes()) {
node.removeChild(node.lastChild);
}
}
+
+ function fetchProgressCallback(report) {
+ document.getElementById("fetch-text").innerHTML = report.text;
+ document.getElementById("fetch-progress").value = (report.fetchedBytes /
report.totalBytes) * 100;
+ }
+
function connectRPC() {
- const proxyUrl = document.getElementById("proxyURL").value;
+ const proxyUrl = document.getElementById("proxyUrl").value;
const key = document.getElementById("proxyKey").value;
+ const ndarrayCacheName = document.getElementById("cache-select").value;
+ let ndarrayCacheUrl = new URL(ndarrayCacheName + "/", document.URL).href;
+ let ndarrayCacheDevice =
document.getElementById("ndarrayCacheDevice").value;
+
+ if (ndarrayCacheName == "none" || ndarrayCacheName === undefined) {
+ ndarrayCacheUrl = "";
+ }
+
// only works for once.
const getImports = () => {
return new EmccWASI();
};
- new tvmjs.RPCServer(proxyUrl, key, getImports, customLog);
+ new tvmjs.RPCServer(
+ proxyUrl, key, getImports, customLog,
+ ndarrayCacheUrl, ndarrayCacheDevice, fetchProgressCallback);
+ }
+
+ async function loadCacheOption() {
+ const select = document.getElementById("cache-select");
+ try {
+ const list = await (await fetch("/cache-list.json")).json()
+ for (let i = 0; i < list.length; ++i) {
+ const option = document.createElement("option");
+ option.text = list[i];
+ option.value = list[i];
+ select.add(option);
+ }
+ if (list.length != 0) {
+ select.value = list[0];
+ }
+ } catch (err) {}
}
</script>
- <body>
+ <body onload="loadCacheOption()">
<h1>TVM WebSocket RPC Server</h1>
To use this page
<ul>
@@ -59,20 +92,34 @@
</ul>
<h2>Options</h2>
- Proxy URL<input
- name="proxyurl"
- id="proxyURL"
+ Proxy URL <input
+ name="proxyrl"
+ id="proxyUrl"
type="text"
value="ws://localhost:8888/ws"
/><br />
- RPC Server Key<input
+ RPC Server Key <input
name="serverkey"
id="proxyKey"
type="text"
value="wasm"
/><br />
+ NDArrayCache -
+ <select name="cache-name" id="cache-select">
+ <option value="none">none</option>
+ </select>
+ CacheDevice -
+ <select name="cache-device" id="ndarrayCacheDevice">
+ <option value="webgpu">webgpu</option>
+ <option value="cpu">cpu</option>
+ </select>
+ <br />
<button onclick="connectRPC()">Connect To Proxy</button>
<button onclick="clearLog()">Clear Log</button>
+ <div id="progress">
+ <label id="fetch-text"></div>
+ <progress id="fetch-progress" max="100" value="100"> </progress>
+ </div>
<div id="log"></div>
<canvas id="canvas" width="224" height="224"></canvas>
</body>
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 00d2a8c579..c90b917c5c 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -26,6 +26,7 @@
#define TVM_LOG_STACK_TRACE 0
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
+
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
#include <tvm/runtime/c_runtime_api.h>
@@ -51,6 +52,12 @@
#include "src/runtime/rpc/rpc_session.cc"
#include "src/runtime/system_library.cc"
#include "src/runtime/workspace_pool.cc"
+// relax setup
+#include "src/runtime/relax_vm/builtin.cc"
+#include "src/runtime/relax_vm/bytecode.cc"
+#include "src/runtime/relax_vm/executable.cc"
+#include "src/runtime/relax_vm/memory_manager.cc"
+#include "src/runtime/relax_vm/vm.cc"
// --- Implementations of backend and wasm runtime API. ---
@@ -111,5 +118,72 @@
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
// and get another value.
*ret = (obj.use_count() - 1);
});
+
+/*!
+ * A NDArray cache to store pre-loaded arrays in the system.
+ */
+class NDArrayCache {
+ public:
+ static NDArrayCache* Global() {
+ static NDArrayCache* inst = new NDArrayCache();
+ return inst;
+ }
+
+ static void Update(String name, NDArray arr, bool override) {
+ NDArrayCache* pool = Global();
+ if (!override) {
+ ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already
exists in the cache";
+ }
+ pool->pool_.Set(name, arr);
+ }
+
+ static Optional<NDArray> Get(String name) {
+ NDArrayCache* pool = Global();
+ auto it = pool->pool_.find(name);
+ if (it != pool->pool_.end()) {
+ return (*it).second;
+ } else {
+ return NullOpt;
+ }
+ }
+
+ static void Remove(String name) {
+ NDArrayCache* pool = Global();
+ pool->pool_.erase(name);
+ }
+
+ static void Clear() { Global()->pool_.clear(); }
+
+ private:
+ Map<String, NDArray> pool_;
+};
+
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.update").set_body_typed(NDArrayCache::Update);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
+TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);
+
+void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string
format) {
+ if (format == "f32-to-bf16") {
+ std::vector<uint16_t> buffer(bytes.length() / 2);
+ std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2);
+ // decode bf16 to f32
+ const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(buffer.data());
+ uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
+ ICHECK(cpu_arr.IsContiguous());
+ size_t size = 1;
+ for (int i = 0; i < cpu_arr->ndim; ++i) {
+ size *= cpu_arr->shape[i];
+ }
+ ICHECK_EQ(size, bytes.length() / 2);
+ for (size_t i = 0; i < size; ++i) {
+ data[i] = static_cast<uint32_t>(bf16[i]) << 16;
+ }
+ } else {
+ cpu_arr.CopyFromBytes(bytes.data(), bytes.length());
+ }
+}
+
+TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
} // namespace runtime
} // namespace tvm
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index e37d1838d6..4dd7228d3c 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -82,6 +82,9 @@ export class RPCServer {
state: RPCServerState = RPCServerState.InitHeader;
logger: (msg: string) => void;
getImports: () => Record<string, unknown>;
+ private ndarrayCacheUrl: string;
+ private ndarrayCacheDevice: string;
+ private fetchProgressCallback?: runtime.FetchProgressCallback;
private pendingSend: Promise<void> = Promise.resolve();
private name: string;
private inst?: runtime.Instance = undefined;
@@ -98,13 +101,19 @@ export class RPCServer {
url: string,
key: string,
getImports: () => Record<string, unknown>,
- logger: (msg: string) => void = console.log
+ logger: (msg: string) => void = console.log,
+ ndarrayCacheUrl: string = "",
+ ndarrayCacheDevice: string = "cpu",
+ fetchProgressCallback: runtime.FetchProgressCallback | undefined =
undefined
) {
this.url = url;
this.key = key;
this.name = "WebSocketRPCServer[" + this.key + "]: ";
this.getImports = getImports;
this.logger = logger;
+ this.ndarrayCacheUrl = ndarrayCacheUrl;
+ this.ndarrayCacheDevice = ndarrayCacheDevice;
+ this.fetchProgressCallback = fetchProgressCallback;
this.checkLittleEndian();
this.socket = compact.createWebSocket(url);
@@ -132,7 +141,9 @@ export class RPCServer {
if (this.state == RPCServerState.ReceivePacketHeader) {
this.log("Closing the server in clean state");
this.log("Automatic reconnecting..");
- new RPCServer(this.url, this.key, this.getImports, this.logger);
+ new RPCServer(
+ this.url, this.key, this.getImports, this.logger,
+ this.ndarrayCacheUrl, this.ndarrayCacheDevice,
this.fetchProgressCallback);
} else {
this.log("Closing the server, final state=" + this.state);
}
@@ -272,6 +283,20 @@ export class RPCServer {
// begin scope to allow handling of objects
// the object should stay alive during all sessions.
this.inst.beginScope();
+ if (this.fetchProgressCallback !== undefined) {
+ this.inst.registerFetchProgressCallback(this.fetchProgressCallback);
+ }
+
+ if (this.ndarrayCacheUrl.length != 0) {
+ if (this.ndarrayCacheDevice == "cpu") {
+ await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl,
this.inst.cpu());
+ } else {
+ assert(this.ndarrayCacheDevice == "webgpu");
+ await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl,
this.inst.webgpu());
+ }
+ }
+
+ assert(this.inst !== undefined);
const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");
const messageHandler = fcreate(
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index a24459ca29..463532762e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -29,7 +29,6 @@ import { WebGPUContext } from "./webgpu";
import * as compact from "./compact";
import * as ctypes from "./ctypes";
-import { tsImportEqualsDeclaration } from "@babel/types";
/**
* Type for PackedFunc inthe TVMRuntime.
@@ -144,6 +143,11 @@ class RuntimeContext implements Disposable {
arrayGetSize : PackedFunc;
arrayMake : PackedFunc;
getSysLib: PackedFunc;
+ arrayCacheGet: PackedFunc;
+ arrayCacheUpdate: PackedFunc;
+ arrayCacheRemove: PackedFunc;
+ arrayCacheClear: PackedFunc;
+ arrayDecodeStorage: PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -152,12 +156,25 @@ class RuntimeContext implements Disposable {
this.arrayGetSize = getGlobalFunc("runtime.ArraySize");
this.arrayMake = getGlobalFunc("runtime.Array");
this.getSysLib = getGlobalFunc("runtime.SystemLib");
+ this.arrayCacheGet = getGlobalFunc("tvmjs.ndarray_cache.get");
+ this.arrayCacheRemove = getGlobalFunc("tvmjs.ndarray_cache.remove");
+ this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update");
+ this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
+ this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
+
}
dispose(): void {
+ // call array cache clear to clear all cached items
+ this.arrayCacheClear();
this.arrayGetItem.dispose();
this.arrayGetSize.dispose();
this.arrayMake.dispose();
+ this.arrayCacheGet.dispose();
+ this.arrayCacheRemove.dispose();
+ this.arrayCacheUpdate.dispose();
+ this.arrayCacheClear.dispose();
+ this.arrayDecodeStorage.dispose();
}
beginScope() : void {
@@ -522,6 +539,9 @@ export class NDArray implements Disposable {
* @returns this
*/
copyFromRawBytes(data: Uint8Array): this {
+ if (this.device.deviceType != DeviceStrToEnum.cpu) {
+ throw new Error("Can only sync copy CPU array, use
cpu_arr.copyfrom(gpu_arr) then sync instead.");
+ }
const size = this.shape.reduce((a, b) => {
return a * b;
}, 1);
@@ -552,7 +572,7 @@ export class NDArray implements Disposable {
*/
toRawBytes(): Uint8Array {
if (this.device.deviceType != DeviceStrToEnum.cpu) {
- throw new Error("Can only synchronize copy for GPU array, use copyfrom
instead.");
+ throw new Error("Can only sync copy CPU array, use
cpu_arr.copyfrom(gpu_arr) then sync instead.");
}
const size = this.shape.reduce((a, b) => {
return a * b;
@@ -806,12 +826,70 @@ export class TVMArray extends TVMObject {
}
}
+export const enum VMAllocatorKind {
+ NAIVE_ALLOCATOR = 1,
+ POOLED_ALLOCATOR = 2,
+}
+
+/**
+ * VirtualMachine Executor.
+ *
+ * This is a thin wrapper of the underlying TVM module.
+ * you can also directly call set_input, run, and get_output
+ * of underlying module functions
+ */
+export class VirtualMachine implements Disposable {
+ private mod: Module;
+ /**
+ * Constructor
+ * @param mod The underlying module, need to be detached.
+ * @param device The main device to run VM on.
+ */
+ constructor(mod: Module, device: DLDevice) {
+ this.mod = mod;
+ this.mod.getFunction("vm_initialization")(
+ new Scalar(device.deviceType, "int"),
+ new Scalar(device.deviceId, "int"),
+ new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int")
+ );
+ }
+
+ dispose(): void {
+ this.mod.dispose();
+ }
+ /**
+ * Get a function in the VM module.
+ * @param name The name of the function.
+ * @returns The result function.
+ */
+ getFunction(name: string): PackedFunc {
+ return this.mod.getFunction(name);
+ }
+}
+
/** Code used as the first argument of the async callback. */
const enum AyncCallbackCode {
kReturn = 4,
kException = 5,
}
+export interface NDArrayCacheEntry {
+ name: string;
+ shape: Array<number>;
+ dtype: string;
+ format: "f32-to-bf16" | "raw";
+ dataPath: string;
+}
+
+export interface FetchProgressReport {
+ fetchedBytes: number;
+ totalBytes: number;
+ timeElapsed: number;
+ text: string;
+}
+
+export type FetchProgressCallback = (report: FetchProgressReport) => void;
+
/**
* TVM runtime instance.
*
@@ -836,6 +914,7 @@ export class Instance implements Disposable {
private env: Environment;
private objFactory: Map<number, FObjectConstructor>;
private ctx: RuntimeContext;
+ private fetchProgressCallback: Array<FetchProgressCallback> = [];
/**
* Internal function(registered by the runtime)
@@ -898,26 +977,26 @@ export class Instance implements Disposable {
* @number The number of times to compute the average.
* @repeat The number of times to repeat the run.
*/
- async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4):
Promise<number[]> {
- // Skip first run as it can involve GPU warmup and module loading time.
- const perf = compact.getPerformance();
- const results = [];
+ async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4):
Promise<number[]> {
+ // Skip first run as it can involve GPU warmup and module loading time.
+ const perf = compact.getPerformance();
+ const results = [];
- // run with new scope
- this.withNewScope(run);
- await dev.sync();
+ // run with new scope
+ this.withNewScope(run);
+ await dev.sync();
- for (let k = 0; k < repeat; ++k) {
- const tstart = perf.now();
- for (let i = 0; i < number; ++i) {
- this.withNewScope(run);
- }
- await dev.sync();
- const tend = perf.now();
- results.push((tend - tstart) / number);
+ for (let k = 0; k < repeat; ++k) {
+ const tstart = perf.now();
+ for (let i = 0; i < number; ++i) {
+ this.withNewScope(run);
}
- return results;
+ await dev.sync();
+ const tend = perf.now();
+ results.push((tend - tstart) / number);
}
+ return results;
+ }
dispose(): void {
// order matters
@@ -1131,9 +1210,9 @@ export class Instance implements Disposable {
* @param func Input function.
* @returns The converted function.
*/
- toPackedFunc(func: Function): PackedFunc {
- return this.toPackedFuncInternal(func, true);
- }
+ toPackedFunc(func: Function): PackedFunc {
+ return this.toPackedFuncInternal(func, true);
+ }
private toPackedFuncInternal(func: Function, autoAttachToScope: boolean):
PackedFunc {
if (this.isPackedFunc(func)) return func as PackedFunc;
@@ -1142,6 +1221,200 @@ export class Instance implements Disposable {
return ret;
}
+ /**
+ * Setup a virtual machine module with given device.
+ *
+ * @param dev DLDevice the device.
+ * @returns The created virtual machine.
+ */
+ createVirtualMachine(dev: DLDevice): VirtualMachine {
+ const mod = this.ctx.detachFromCurrentScope(
+ this.systemLib().getFunction("vm_load_executable")()
+ );
+ return this.ctx.attachToCurrentScope(
+ new VirtualMachine(mod, dev)
+ );
+ }
+
+ //-----------------------------------------------
+ // Native NDArray Cache Support
+ //-----------------------------------------------
+ /**
+ * Register a call back for fetch progress.
+ *
+ * @param cb the fetch progress callback.
+ */
+ registerFetchProgressCallback(cb: FetchProgressCallback) {
+ this.fetchProgressCallback.push(cb);
+ }
+
+ /**
+ * Get NDArray from cache.
+ * @param name The name of array.
+ * @returns The result.
+ */
+ ndarrayCacheGet(name: string) : NDArray | undefined {
+ return this.ctx.arrayCacheGet(name);
+ }
+
+ /**
+ * Remove NDArray from cache.
+ * @param name The name of array.
+ * @returns The result.
+ */
+ ndarrayCacheRemove(name: string) : NDArray | undefined {
+ return this.ctx.arrayCacheRemove(name);
+ }
+
+ /**
+ * Update the ndarray cache.
+ * @param name The name of the array.
+ * @param arr The content.
+ */
+ ndarrayCacheUpdate(name: string, arr: NDArray, override: boolean = false) {
+ this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0,
"int32"));
+ }
+
+ /**
+ * Clear the ndarray cache.
+ */
+ ndarrayCacheClear() {
+ this.ctx.arrayCacheClear();
+ }
+
+ /**
+ * Fetch NDArray cache from url.
+ *
+ * @param ndarrayCacheUrl The cache url.
+ * @param device The device to be fetched to.
+ */
+ async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) {
+ const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
+ var list;
+ try {
+
+ list = await (await fetch(jsonUrl)).json();
+ } catch(err) {
+ this.env.logger("Cannot fetch " + jsonUrl);
+ }
+ await this.fetchNDArrayCacheInternal(ndarrayCacheUrl, list as
Array<NDArrayCacheEntry>, device);
+ }
+
+ /**
+ * Fetch list of NDArray into the NDArrayCache.
+ *
+ * @param ndarrayCacheUrl The cache url.
+ * @param list The list of array data.
+ * @param device The device to store the data to.
+ */
+ private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list:
Array<NDArrayCacheEntry>, device: DLDevice) {
+ const computeTotalBytes = (rec: NDArrayCacheEntry) => {
+
+ const dtype = this.toDLDataType(rec.dtype);
+ const size = rec.shape.reduce((a, b) => {
+ return a * b;
+ }, 1);
+ if (rec.format == "f32-to-bf16" && rec.dtype == "float32") {
+ return size * 2;
+ }
+ return size * dtype.bits * dtype.lanes / 8;
+ };
+ const perf = compact.getPerformance();
+ let tstart = perf.now();
+
+ let totalBytes = 0;
+ for (let i = 0; i < list.length; ++i) {
+ totalBytes += computeTotalBytes(list[i]);
+ };
+ let fetchedBytes = 0;
+ let timeElapsed = 0;
+
+ const reportCallback = (iter: number)=> {
+ // report
+ for (let j = 0; j < this.fetchProgressCallback.length; ++j) {
+ let text = "Fetching NDArray Cache[" + iter + "/" + list.length+ "]:";
+ text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB
fetched "
+ text += "from " + Math.ceil(totalBytes / (1024 * 1024)).toString() +
"MB, "
+ text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "%
completed, "
+ text += timeElapsed + " secs elapsed";
+ if (timeElapsed != 0){
+ text += ", speed=" + (fetchedBytes / timeElapsed / (1024 *
1024)).toFixed(1) + " MB/sec";
+ }
+ this.fetchProgressCallback[j]({
+ fetchedBytes: fetchedBytes,
+ totalBytes: totalBytes,
+ timeElapsed: timeElapsed,
+ text: text
+ });
+ }
+ };
+
+ for (let j = 0; j < this.fetchProgressCallback.length; ++j) {
+ this.fetchProgressCallback[j]({
+ fetchedBytes: 0,
+ totalBytes: totalBytes,
+ timeElapsed: 0,
+ text: "Start to fetch " + ndarrayCacheUrl
+ });
+ }
+ const cache = await caches.open("tvmjs");
+
+ for (let i = 0; i < list.length; ++i) {
+ const rec = list[i];
+ reportCallback(i);
+ fetchedBytes += computeTotalBytes(rec);
+ const cpu_arr = this.withNewScope(() => {
+ return this.detachFromCurrentScope(
+ this.empty(rec.shape, rec.dtype, this.cpu())
+ )
+ });
+ const dataUrl = new URL(rec.dataPath, ndarrayCacheUrl).href;
+ const request = new Request(dataUrl);
+
+ let buffer;
+ try {
+ // use native cache
+ let result = await cache.match(request);
+ if (result === undefined) {
+ await cache.add(request);
+ result = await cache.match(request);
+ }
+ if (result == undefined) {
+ this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will
be slow");
+ result = await fetch(request);
+ }
+ buffer = await result.arrayBuffer();
+ } catch (err) {
+ this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
+ cpu_arr.dispose();
+ throw err;
+ }
+ // first sync copy to cpu.
+ this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(buffer), rec.format);
+ // then async stream into GPU if needed
+ if (device.deviceType == DeviceStrToEnum.cpu) {
+ this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
+ cpu_arr.dispose();
+ } else {
+ // allocate a gpu arr and async copy to it.
+ const gpu_arr = this.withNewScope(() => {
+ return this.detachFromCurrentScope(
+ this.empty(rec.shape, rec.dtype, device)
+ )
+ });
+ gpu_arr.copyFrom(cpu_arr);
+ await device.sync();
+ this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
+ cpu_arr.dispose();
+ gpu_arr.dispose();
+ }
+ timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
+ }
+ reportCallback(list.length);
+ }
+
/**
* Convert dtype to {@link DLDataType}
*
diff --git a/web/tests/node/test_relax_vm.js b/web/tests/node/test_relax_vm.js
new file mode 100644
index 0000000000..ceb47aa014
--- /dev/null
+++ b/web/tests/node/test_relax_vm.js
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/* eslint-disable no-undef */
+// Load Emscripten Module, need to change path to root/lib
+const path = require("path");
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "test_relax.wasm"));
+
+const tvm = new tvmjs.Instance(
+ new WebAssembly.Module(wasmSource),
+ new EmccWASI()
+);
+
+
+function randomArray(length, max) {
+ return Array.apply(null, Array(length)).map(function () {
+ return Math.random() * max;
+ });
+}
+
+test("add one", () => {
+ tvm.beginScope();
+ // Load system library
+ const vm = tvm.createVirtualMachine(tvm.cpu());
+ // grab pre-loaded function
+ const fadd = vm.getFunction("main");
+
+ assert(tvm.isPackedFunc(fadd));
+ const n = 124;
+ const A = tvm.empty(n).copyFrom(randomArray(n, 1));
+ const B = tvm.empty(n).copyFrom(randomArray(n, 1));
+
+ // call the function.
+ const C = fadd(A, B);
+ const AA = A.toArray(); // retrieve values in js array
+ const BB = B.toArray(); // retrieve values in js array
+ const CC = C.toArray(); // retrieve values in js array
+ // verify
+ for (var i = 0; i < BB.length; ++i) {
+ assert(Math.abs(CC[i] - (AA[i] + BB[i])) < 1e-5);
+ }
+ tvm.endScope();
+ // assert auto release scope behavior
+ assert(vm.mod.getHandle(false) == 0);
+ assert(fadd._tvmPackedCell.getHandle(false) == 0);
+});
diff --git a/web/tests/python/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py
index 5c1f7c68c4..a63e0655b4 100644
--- a/web/tests/python/prepare_test_libs.py
+++ b/web/tests/python/prepare_test_libs.py
@@ -18,12 +18,32 @@
import tvm
from tvm import te
-from tvm.contrib import emcc
+from tvm.contrib import tvmjs
from tvm.relay.backend import Runtime
+from tvm import relax
+from tvm.script import relax as R
import os
-def prepare_test_libs(base_path):
+def prepare_relax_lib(base_path):
+ pipeline = relax.get_pipeline()
+
+ @tvm.script.ir_module
+ class Mod:
+ @R.function
+ def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")):
+ lv0 = R.add(x, y)
+ return lv0
+
+ target = tvm.target.Target("llvm -mtriple=wasm32-unknown-unknown-wasm")
+
+ mod = pipeline(Mod)
+ ex = relax.build(mod, target)
+ wasm_path = os.path.join(base_path, "test_relax.wasm")
+ ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
+
+
+def prepare_tir_lib(base_path):
runtime = Runtime("cpp", {"system-lib": True})
target = "llvm -mtriple=wasm32-unknown-unknown-wasm"
if not tvm.runtime.enabled(target):
@@ -35,9 +55,11 @@ def prepare_test_libs(base_path):
fadd = tvm.build(s, [A, B], target, runtime=runtime, name="add_one")
wasm_path = os.path.join(base_path, "test_addone.wasm")
- fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+ fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
if __name__ == "__main__":
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
- prepare_test_libs(os.path.join(curr_path, "../../dist/wasm"))
+ base_path = os.path.join(curr_path, "../../dist/wasm")
+ prepare_tir_lib(base_path)
+ prepare_relax_lib(base_path)
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/relax_rpc_test.py
similarity index 51%
copy from web/tests/python/webgpu_rpc_test.py
copy to web/tests/python/relax_rpc_test.py
index 6e34a8a2b3..a347fe70b3 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/relax_rpc_test.py
@@ -14,47 +14,53 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Simple testcode to test Javascript RPC
-
-To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
-Connect javascript end to the websocket port and connect to the RPC.
-"""
+"""Test relax vm through rpc."""
import tvm
-from tvm import te
-from tvm import rpc
-from tvm.contrib import utils, emcc
-from tvm.relay.backend import Runtime
import numpy as np
+from tvm import rpc, relax
+from tvm.contrib import utils, tvmjs
+from tvm.script import relax as R
proxy_host = "127.0.0.1"
proxy_port = 9090
-def test_rpc():
- if not tvm.runtime.enabled("rpc"):
- return
- # generate the wasm library
- target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm")
- runtime = Runtime("cpp", {"system-lib": True})
+def get_model():
+ pipeline = relax.get_pipeline()
- n = 2048
- A = te.placeholder((n,), name="A")
- B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B")
- s = te.create_schedule(B.op)
+ @tvm.script.ir_module
+ class Mod:
+ @R.function
+ def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], "float32")):
+ lv0 = R.add(x, y)
+ return lv0
- num_thread = 2
- xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
- s[B].bind(xi, te.thread_axis("threadIdx.x"))
- s[B].bind(xo, te.thread_axis("blockIdx.x"))
+ mod = pipeline(Mod)
+ sch = tvm.tir.Schedule(mod)
+ # manually transform loop
+ sch.work_on("add")
+ (i,) = sch.get_loops(block=sch.get_block("T_add"))
+ i0, i1 = sch.split(i, [None, 128])
+ sch.bind(i0, "blockIdx.x")
+ sch.bind(i1, "threadIdx.x")
+ return sch.mod
- fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone")
- temp = utils.tempdir()
- wasm_path = temp.relpath("addone_gpu.wasm")
- fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+def test_rpc():
+ if not tvm.runtime.enabled("rpc"):
+ return
+ n = 1024
+ dtype = "float32"
+ temp = utils.tempdir()
+ wasm_path = temp.relpath("relax.wasm")
+ target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm")
+ mod = get_model()
+ ex = relax.build(mod, target)
+ ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
wasm_binary = open(wasm_path, "rb").read()
+
remote = rpc.connect(
proxy_host,
proxy_port,
@@ -63,18 +69,17 @@ def test_rpc():
)
def check(remote):
- # basic function checks.
dev = remote.webgpu(0)
- adata = np.random.uniform(size=n).astype(A.dtype)
+ # invoke the function
+ vm = relax.VirtualMachine(remote.system_lib(), device=dev)
+ adata = np.random.uniform(size=n).astype(dtype)
+ bdata = np.random.uniform(size=n).astype(dtype)
a = tvm.nd.array(adata, dev)
- b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
-
- np.testing.assert_equal(a.numpy(), adata)
- f1 = remote.system_lib()
- addone = f1.get_function("addone")
- addone(a, b)
- np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5)
- print("Test pass..")
+ b = tvm.nd.array(bdata, dev)
+ vm.set_input("main", a, b)
+ vm.invoke_stateful("main")
+ c = vm.get_outputs("main")
+ np.testing.assert_equal(c.numpy(), a.numpy() + b.numpy())
check(remote)
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py
index 6e34a8a2b3..986393e9d4 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -23,7 +23,7 @@ Connect javascript end to the websocket port and connect to the RPC.
import tvm
from tvm import te
from tvm import rpc
-from tvm.contrib import utils, emcc
+from tvm.contrib import utils, tvmjs
from tvm.relay.backend import Runtime
import numpy as np
@@ -52,7 +52,7 @@ def test_rpc():
temp = utils.tempdir()
wasm_path = temp.relpath("addone_gpu.wasm")
- fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+ fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
wasm_binary = open(wasm_path, "rb").read()
remote = rpc.connect(
diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py
index 7de5ee956e..19d5dc5748 100644
--- a/web/tests/python/websock_rpc_test.py
+++ b/web/tests/python/websock_rpc_test.py
@@ -23,7 +23,7 @@ Connect javascript end to the websocket port and connect to the RPC.
import tvm
from tvm import te
from tvm import rpc
-from tvm.contrib import utils, emcc
+from tvm.contrib import utils, tvmjs
from tvm.relay.backend import Runtime
import numpy as np
@@ -48,7 +48,7 @@ def test_rpc():
temp = utils.tempdir()
wasm_path = temp.relpath("addone.wasm")
- fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+ fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm)
wasm_binary = open(wasm_path, "rb").read()