This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d5cd94595e [REFACTOR][FFI][Web] Upgrade Web Runtime to new FFI (#17946)
d5cd94595e is described below
commit d5cd94595e40214c76ef4215fb69aeb44fe16fd3
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun May 11 13:40:23 2025 -0400
[REFACTOR][FFI][Web] Upgrade Web Runtime to new FFI (#17946)
This PR refactors the web runtime to the new FFI protocol.
Tested through RPC tests and local tests.
---
ffi/include/tvm/ffi/c_api.h | 4 +-
ffi/include/tvm/ffi/dtype.h | 2 +-
ffi/include/tvm/ffi/error.h | 10 +-
ffi/src/ffi/dtype.cc | 4 +-
python/tvm/ffi/cython/base.pxi | 2 +-
python/tvm/ffi/cython/dtype.pxi | 2 +-
web/.eslintignore | 2 +
web/apps/node/example.js | 2 +-
web/emcc/tvmjs_support.cc | 46 +-
web/emcc/wasm_runtime.cc | 55 ++-
web/emcc/webgpu_runtime.cc | 22 +-
web/package.json | 4 +
web/src/asyncify.ts | 9 +
web/src/ctypes.ts | 296 +++++--------
web/src/environment.ts | 27 +-
web/src/memory.ts | 76 +++-
web/src/rpc_server.ts | 15 +-
web/src/runtime.ts | 813 +++++++++++++-----------------------
web/tests/node/test_ndarray.js | 2 +-
web/tests/node/test_object.js | 5 -
web/tests/node/test_packed_func.js | 53 ++-
web/tests/python/webgpu_rpc_test.py | 1 -
22 files changed, 604 insertions(+), 848 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 1d495d9c5e..131f2e73e0 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -579,8 +579,10 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const
TVMFFIByteArray* str, DLDataType*
* \return 0 when success, nonzero when failure happens
* \note out is a String object that needs to be freed by the caller via
TVMFFIObjectFree.
The content of string can be accessed via TVMFFIObjectGetByteArrayPtr.
+
+ * \note The input dtype is a pointer to the DLDataType to avoid ABI
compatibility issues.
*/
-TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle*
out);
+TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype,
TVMFFIObjectHandle* out);
//------------------------------------------------------------
// Section: Backend noexcept functions for internal use
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index 99eb227ee1..a1a6b58afa 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -121,7 +121,7 @@ inline DLDataType StringToDLDataType(const String& str) {
inline String DLDataTypeToString(DLDataType dtype) {
TVMFFIObjectHandle out;
- TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out));
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
return
String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
}
diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h
index 239a0e500b..de754bd6ea 100644
--- a/ffi/include/tvm/ffi/error.h
+++ b/ffi/include/tvm/ffi/error.h
@@ -51,6 +51,10 @@
#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1
#endif
+#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW
+#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0
+#endif
+
namespace tvm {
namespace ffi {
@@ -212,8 +216,10 @@ class ErrorBuilder {
*
* \endcode
*/
-#define TVM_FFI_THROW(ErrorKind) \
- ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE,
false).stream()
+#define TVM_FFI_THROW(ErrorKind) \
+ ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \
+ TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \
+ .stream()
/*!
* \brief Explicitly log error in stderr and then throw the error.
diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc
index 7661ab4b97..cb0bd49597 100644
--- a/ffi/src/ffi/dtype.cc
+++ b/ffi/src/ffi/dtype.cc
@@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str,
DLDataType* out) {
TVM_FFI_SAFE_CALL_END();
}
-int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) {
+int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) {
TVM_FFI_SAFE_CALL_BEGIN();
- tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype));
+ tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype));
*out =
tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str));
TVM_FFI_SAFE_CALL_END();
}
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 8fe23cd23b..8b9c1f3d94 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -150,7 +150,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil
int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex)
nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
- int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) nogil
+ int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle*
out) nogil
const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno,
const char* func) nogil;
int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t
require_alignment,
int32_t require_contiguous,
TVMFFIObjectHandle* out) nogil
diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi
index 30f9f274b4..80ec5d9364 100644
--- a/python/tvm/ffi/cython/dtype.pxi
+++ b/python/tvm/ffi/cython/dtype.pxi
@@ -94,7 +94,7 @@ cdef class DataType:
def __str__(self):
cdef TVMFFIObjectHandle dtype_str
cdef TVMFFIByteArray* bytes
- CHECK_CALL(TVMFFIDataTypeToString(self.cdtype, &dtype_str))
+ CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str))
bytes = TVMFFIBytesGetByteArrayPtr(dtype_str)
res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
CHECK_CALL(TVMFFIObjectFree(dtype_str))
diff --git a/web/.eslintignore b/web/.eslintignore
index f71ee79871..1549e07c25 100644
--- a/web/.eslintignore
+++ b/web/.eslintignore
@@ -1,2 +1,4 @@
dist
debug
+tvmjs_runtime_wasi.js
+lib
diff --git a/web/apps/node/example.js b/web/apps/node/example.js
index 580bbf57ab..62c9157c7c 100644
--- a/web/apps/node/example.js
+++ b/web/apps/node/example.js
@@ -31,7 +31,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath,
"tvmjs_runtime.wasm"));
tvmjs.instantiate(wasmSource, tvmjs.createPolyfillWASI())
.then((tvm) => {
tvm.beginScope();
- const log_info = tvm.getGlobalFunc("testing.log_info_str");
+ const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str");
log_info("hello world");
// List all the global functions from the runtime.
console.log("Runtime functions using EmccWASI\n",
tvm.listGlobalFuncNames());
diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc
index e50e6c37d3..1e35a1137f 100644
--- a/web/emcc/tvmjs_support.cc
+++ b/web/emcc/tvmjs_support.cc
@@ -28,12 +28,11 @@
#define TVM_LOG_STACK_TRACE 0
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
+#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
-#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
#include "../../src/runtime/rpc/rpc_local_session.h"
@@ -59,27 +58,33 @@ TVM_DLL void TVMWasmFreeSpace(void* data);
* \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
* \return 0 if success.
*/
-TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle,
TVMFunctionHandle* out);
+TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle*
out);
+
+/*!
+ * \brief Get the last error message.
+ * \return The last error message.
+ */
+TVM_DLL const char* TVMFFIWasmGetLastError();
// --- APIs to be implemented by the frontend. ---
+
/*!
- * \brief Wasm frontend packed function caller.
+ * \brief Wasm frontend new ffi call function caller.
*
+ * \param self The pointer to the ffi::Function.
* \param args The arguments
- * \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
- * \param ret The return value handle.
- * \param resource_handle The handle additional resource handle from front-end.
+ * \param result The return value handle.
* \return 0 if success, -1 if failure happens, set error via
TVMAPISetLastError.
*/
-extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args,
TVMRetValueHandle ret,
- void* resource_handle);
-
+extern int TVMFFIWasmSafeCall(void* self, const TVMFFIAny* args, int32_t
num_args,
+ TVMFFIAny* result);
/*!
- * \brief Wasm frontend resource finalizer.
- * \param resource_handle The pointer to the external resource.
+ * \brief Delete ffi::Function.
+ * \param self The pointer to the ffi::Function.
*/
-extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
+extern void TVMFFIWasmFunctionDeleter(void* self);
+
} // extern "C"
void* TVMWasmAllocSpace(int size) {
@@ -89,9 +94,14 @@ void* TVMWasmAllocSpace(int size) {
void TVMWasmFreeSpace(void* arr) { delete[] static_cast<int64_t*>(arr); }
-int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
- return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle,
TVMWasmPackedCFuncFinalizer,
- out);
+int TVMFFIWasmFunctionCreate(void* self, TVMFunctionHandle* out) {
+ return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall,
TVMFFIWasmFunctionDeleter, out);
+}
+
+const char* TVMFFIWasmGetLastError() {
+ static thread_local std::string last_error;
+ last_error = ::tvm::ffi::details::MoveFromSafeCallRaised().what();
+ return last_error.c_str();
}
namespace tvm {
@@ -291,7 +301,7 @@ class AsyncLocalSession : public LocalSession {
}
};
-TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
+TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
});
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index b208daed51..728e1c648c 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -27,9 +27,9 @@
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
#define TVM_FFI_USE_LIBBACKTRACE 0
+#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
-#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>
#include "src/runtime/c_runtime_api.cc"
@@ -107,45 +107,24 @@ void LogMessageImpl(const std::string& file, int lineno,
int level, const std::s
} // namespace detail
-TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args,
ffi::Any* ret) {
- *ret = args[0];
-});
-
-TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args,
ffi::Any* ret) {
- (args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
-});
-
-TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs
args, ffi::Any* ret) {
- *ret = args[0].cast<String>();
-});
-
-TVM_REGISTER_GLOBAL("testing.log_info_str")
+TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- LOG(INFO) << args[0].cast<String>();
+ (args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
});
-TVM_REGISTER_GLOBAL("testing.log_fatal_str")
+TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- LOG(FATAL) << args[0].cast<String>();
+ LOG(INFO) << args[0].cast<String>();
});
-TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x +
1; });
+TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) {
return x + 1; });
-TVM_REGISTER_GLOBAL("testing.wrap_callback")
+TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
ffi::Function pf = args[0].cast<ffi::Function>();
*ret = ffi::TypedFunction<void()>([pf]() { pf(); });
});
-// internal function used for debug and testing purposes
-TVM_REGISTER_GLOBAL("testing.object_use_count")
- .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- auto obj = args[0].cast<ffi::ObjectRef>();
- // subtract the current one because we always copy
- // and get another value.
- *ret = (obj.use_count() - 1);
- });
-
void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string
format, std::string dtype) {
if (format == "f32-to-bf16" && dtype == "float32") {
std::vector<uint16_t> buffer(bytes.length() / 2);
@@ -167,10 +146,10 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string
bytes, std::string format,
}
}
-TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
+TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
// Concatenate n TVMArrays
-TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
+TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
std::vector<Any> data;
for (int i = 0; i < args.size(); ++i) {
@@ -220,7 +199,7 @@ NDArray ConcatEmbeddings(const std::vector<NDArray>&
embeddings) {
}
// Concatenate n NDArrays
-TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
+TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
std::vector<NDArray> embeddings;
for (int i = 0; i < args.size(); ++i) {
@@ -230,5 +209,19 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
*ret = result;
});
+TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes")
+ .set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) {
+ nd.CopyFromBytes(bytes->data, bytes->size);
+ });
+
+// Copy the full contents of `nd` (GetDataSize bytes) into a new ffi::Bytes.
+TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes")
+ .set_body_typed([](NDArray nd) -> ffi::Bytes {
+ const size_t size = GetDataSize(*(nd.operator->()));
+ std::string bytes(size, '\0');
+ nd.CopyToBytes(bytes.data(), size);
+ // Hand the staging buffer over by rvalue so a move-aware Bytes
+ // constructor can avoid copying the payload a second time.
+ return ffi::Bytes(std::move(bytes));
+ });
+
} // namespace runtime
} // namespace tvm
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index 3d74d77f14..00b1db266a 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -26,13 +26,11 @@
#define TVM_LOG_STACK_TRACE 0
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
+#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
#include <iostream>
#include <string>
@@ -152,7 +150,10 @@ typedef dmlc::ThreadLocalStore<WebGPUThreadEntry>
WebGPUThreadStore;
WebGPUThreadEntry::WebGPUThreadEntry()
: pool(static_cast<DLDeviceType>(kDLWebGPU), WebGPUDeviceAPI::Global()) {}
-WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return
WebGPUThreadStore::Get(); }
+WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() {
+ static thread_local WebGPUThreadEntry inst = WebGPUThreadEntry();
+ return &inst;
+}
class WebGPUModuleNode final : public runtime::ModuleNode {
public:
@@ -241,12 +242,13 @@ Module WebGPUModuleLoadBinary(void* strm) {
}
// for now webgpu is hosted via a vulkan module.
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary);
+TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary);
-TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](ffi::PackedArgs
args, ffi::Any* rv) {
- DeviceAPI* ptr = WebGPUDeviceAPI::Global();
- *rv = static_cast<void*>(ptr);
-});
+TVM_FFI_REGISTER_GLOBAL("device_api.webgpu")
+ .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+ DeviceAPI* ptr = WebGPUDeviceAPI::Global();
+ *rv = static_cast<void*>(ptr);
+ });
} // namespace runtime
} // namespace tvm
diff --git a/web/package.json b/web/package.json
index 583232d209..b4fc25e12f 100644
--- a/web/package.json
+++ b/web/package.json
@@ -45,5 +45,5 @@
"typedoc-plugin-missing-exports": "2.0.0",
"typescript": "^4.9.5",
"ws": "^7.2.5"
}
}
diff --git a/web/src/asyncify.ts b/web/src/asyncify.ts
index 703dbbf80a..6074a559e0 100644
--- a/web/src/asyncify.ts
+++ b/web/src/asyncify.ts
@@ -70,6 +70,15 @@ export class AsyncifyHandler {
return this.exports.asyncify_stop_rewind !== undefined;
}
+ /**
+ * Check whether the asyncify handler is in the normal stack state,
+ * i.e. whether the current state is AsyncifyStateKind.None.
+ *
+ * @returns True if the current asyncify state is AsyncifyStateKind.None.
+ */
+ isNormalStackState(): boolean {
+ return this.state === AsyncifyStateKind.None;
+ }
+
/**
* Get the current asynctify state
*
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
index c4941f07d5..c9a5e263d5 100644
--- a/web/src/ctypes.ts
+++ b/web/src/ctypes.ts
@@ -27,231 +27,165 @@ export type Pointer = number;
/** A pointer offset, need to add a base address to get a valid ptr. */
export type PtrOffset = number;
-// -- TVM runtime C API --
/**
- * const char *TVMGetLastError();
- */
-export type FTVMGetLastError = () => Pointer;
-
-/**
- * void TVMAPISetLastError(const char* msg);
- */
-export type FTVMAPISetLastError = (msg: Pointer) => void;
-
-/**
- * int TVMModGetFunction(TVMModuleHandle mod,
- * const char* func_name,
- * int query_imports,
- * TVMFunctionHandle *out);
- */
-export type FTVMModGetFunction = (
- mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) =>
number;
-/**
- * int TVMModImport(TVMModuleHandle mod,
- * TVMModuleHandle dep);
- */
-export type FTVMModImport = (mod: Pointer, dep: Pointer) => number;
-
-/**
- * int TVMModFree(TVMModuleHandle mod);
- */
-export type FTVMModFree = (mod: Pointer) => number;
-
-/**
- * int TVMFuncFree(TVMFunctionHandle func);
- */
-export type FTVMFuncFree = (func: Pointer) => number;
-
-/**
- * int TVMFuncCall(TVMFunctionHandle func,
- * TVMValue* arg_values,
- * int* type_codes,
- * int num_args,
- * TVMValue* ret_val,
- * int* ret_type_code);
- */
-export type FTVMFuncCall = (
- func: Pointer, argValues: Pointer, typeCode: Pointer,
- nargs: number, retValue: Pointer, retCode: Pointer) => number;
-
-/**
- * int TVMCFuncSetReturn(TVMRetValueHandle ret,
- * TVMValue* value,
- * int* type_code,
- * int num_ret);
- */
-export type FTVMCFuncSetReturn = (
- ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number;
-
-/**
- * int TVMCbArgToReturn(TVMValue* value, int* code);
- */
-export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number;
-
-/**
- * int TVMFuncListGlobalNames(int* outSize, const char*** outArray);
+ * Size of common data types.
*/
-export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) =>
number;
+export const enum SizeOf {
+ U8 = 1,
+ U16 = 2,
+ I32 = 4,
+ I64 = 8,
+ F32 = 4,
+ F64 = 8,
+ TVMValue = 8,
+ TVMFFIAny = 8 * 2,
+ DLDataType = I32,
+ DLDevice = I32 + I32,
+ ObjectHeader = 8 * 2,
+}
+//---------------The new TVM FFI---------------
/**
- * int TVMFuncRegisterGlobal(
- * const char* name, TVMFunctionHandle f, int override);
- */
-export type FTVMFuncRegisterGlobal = (
- name: Pointer, f: Pointer, override: number) => number;
+ * Type Index in new TVM FFI.
+ *
+ * We are keeping the same style as C API here.
+ */
+export const enum TypeIndex {
+ kTVMFFINone = 0,
+ /*! \brief POD int value */
+ kTVMFFIInt = 1,
+ /*! \brief POD bool value */
+ kTVMFFIBool = 2,
+ /*! \brief POD float value */
+ kTVMFFIFloat = 3,
+ /*! \brief Opaque pointer object */
+ kTVMFFIOpaquePtr = 4,
+ /*! \brief DLDataType */
+ kTVMFFIDataType = 5,
+ /*! \brief DLDevice */
+ kTVMFFIDevice = 6,
+ /*! \brief DLTensor* */
+ kTVMFFIDLTensorPtr = 7,
+ /*! \brief const char**/
+ kTVMFFIRawStr = 8,
+ /*! \brief TVMFFIByteArray* */
+ kTVMFFIByteArrayPtr = 9,
+ /*! \brief R-value reference to ObjectRef */
+ kTVMFFIObjectRValueRef = 10,
+ /*! \brief Start of statically defined objects. */
+ kTVMFFIStaticObjectBegin = 64,
+ /*!
+ * \brief Object, all objects starts with TVMFFIObject as its header.
+ * \note We will also add other fields
+ */
+ kTVMFFIObject = 64,
+ /*!
+ * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
+ */
+ kTVMFFIStr = 65,
+ /*!
+ * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
+ */
+ kTVMFFIBytes = 66,
+ /*! \brief Error object. */
+ kTVMFFIError = 67,
+ /*! \brief Function object. */
+ kTVMFFIFunction = 68,
+ /*! \brief Array object. */
+ kTVMFFIArray = 69,
+ /*! \brief Map object. */
+ kTVMFFIMap = 70,
+ /*!
+ * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t },
... }
+ */
+ kTVMFFIShape = 71,
+ /*!
+ * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... }
+ */
+ kTVMFFINDArray = 72,
+ /*! \brief Runtime module object. */
+ kTVMFFIModule = 73,
+}
-/**
- *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
- */
-export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number;
+// -- TVM Wasm Auxiliary C API --
-/**
- * int TVMArrayAlloc(const tvm_index_t* shape,
- * int ndim,
- * int dtype_code,
- * int dtype_bits,
- * int dtype_lanes,
- * int device_type,
- * int device_id,
- * TVMArrayHandle* out);
- */
-export type FTVMArrayAlloc = (
- shape: Pointer, ndim: number,
- dtypeCode: number, dtypeBits: number,
- dtypeLanes: number, deviceType: number, deviceId: number,
- out: Pointer) => number;
+/** void* TVMWasmAllocSpace(int size); */
+export type FTVMWasmAllocSpace = (size: number) => Pointer;
-/**
- * int TVMArrayFree(TVMArrayHandle handle);
- */
-export type FTVMArrayFree = (handle: Pointer) => number;
+/** void TVMWasmFreeSpace(void* data); */
+export type FTVMWasmFreeSpace = (ptr: Pointer) => void;
-/**
- * int TVMArrayCopyFromBytes(TVMArrayHandle handle,
- * void* data,
- * size_t nbytes);
- */
-export type FTVMArrayCopyFromBytes = (
- handle: Pointer, data: Pointer, nbytes: number) => number;
+/** const char* TVMFFIWasmGetLastError(); */
+export type FTVMFFIWasmGetLastError = () => Pointer;
/**
- * int TVMArrayCopyToBytes(TVMArrayHandle handle,
- * void* data,
- * size_t nbytes);
+ * int TVMFFIWasmSafeCallType(void* self, const TVMFFIAny* args,
+ * int32_t num_args, TVMFFIAny* result);
*/
-export type FTVMArrayCopyToBytes = (
- handle: Pointer, data: Pointer, nbytes: number) => number;
+export type FTVMFFIWasmSafeCallType = (
+ self: Pointer, args: Pointer, num_args: number,
+ result: Pointer) => number;
/**
- * int TVMArrayCopyFromTo(TVMArrayHandle from,
- * TVMArrayHandle to,
- * TVMStreamHandle stream);
+ * int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out);
*/
-export type FTVMArrayCopyFromTo = (
- from: Pointer, to: Pointer, stream: Pointer) => number;
+export type FTVMFFIWasmFunctionCreate = (
+ resource_handle: Pointer, out: Pointer) => number;
/**
- * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
+ * void TVMFFIWasmFunctionDeleter(void* self);
*/
-export type FTVMSynchronize = (
- deviceType: number, deviceId: number, stream: Pointer) => number;
+export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void;
/**
- * typedef int (*TVMBackendPackedCFunc)(TVMValue* args,
- * int* type_codes,
- * int num_args,
- * TVMValue* out_ret_value,
- * int* out_ret_tcode);
+ * int TVMFFIObjectFree(TVMFFIObjectHandle obj);
*/
-export type FTVMBackendPackedCFunc = (
- argValues: Pointer, argCodes: Pointer, nargs: number,
- outValue: Pointer, outCode: Pointer) => number;
-
+export type FTVMFFIObjectFree = (obj: Pointer) => number;
/**
- * int TVMObjectFree(TVMObjectHandle obj);
+ * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t*
out_tindex);
*/
-export type FTVMObjectFree = (obj: Pointer) => number;
+export type FTVMFFITypeKeyToIndex = (type_key: Pointer, out_tindex: Pointer)
=> number;
/**
- * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
+ * int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out);
*/
-export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) =>
number;
+export type FTVMFFIAnyViewToOwnedAny = (any_view: Pointer, out: Pointer) =>
number;
/**
- * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);
+ * void TVMFFIErrorSetRaisedByCStr(const char* kind, const char* message);
*/
-export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key:
Pointer) => number;
+export type FTVMFFIErrorSetRaisedByCStr = (kind: Pointer, message: Pointer) =>
void;
/**
- * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
+ * int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle
f,
+ * int override);
*/
-export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer)
=> number;
-
-// -- TVM Wasm Auxiliary C API --
-
-/** void* TVMWasmAllocSpace(int size); */
-export type FTVMWasmAllocSpace = (size: number) => Pointer;
-
-/** void TVMWasmFreeSpace(void* data); */
-export type FTVMWasmFreeSpace = (ptr: Pointer) => void;
+export type FTVMFFIFunctionSetGlobal = (name: Pointer, f: Pointer, override:
number) => number;
/**
- * int TVMWasmPackedCFunc(TVMValue* args,
- * int* type_codes,
- * int num_args,
- * TVMRetValueHandle ret,
- * void* resource_handle);
+ * int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name,
TVMFFIObjectHandle* out);
*/
-export type FTVMWasmPackedCFunc = (
- args: Pointer, typeCodes: Pointer, nargs: number,
- ret: Pointer, resourceHandle: Pointer) => number;
+export type FTVMFFIFunctionGetGlobal = (name: Pointer, out: Pointer) => number;
/**
- * int TVMWasmFuncCreateFromCFunc(void* resource_handle,
- * TVMFunctionHandle *out);
+ * int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t
num_args,
+ * TVMFFIAny* result);
*/
-export type FTVMWasmFuncCreateFromCFunc = (
- resource: Pointer, out: Pointer) => number;
+export type FTVMFFIFunctionCall = (func: Pointer, args: Pointer, num_args:
number,
+ result: Pointer) => number;
/**
- * void TVMWasmPackedCFuncFinalizer(void* resource_handle);
+ * int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out);
*/
-export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void;
+export type FTVMFFIDataTypeFromString = (str: Pointer, out: Pointer) => number;
/**
- * Size of common data types.
+ * int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle*
out);
*/
-export const enum SizeOf {
- U8 = 1,
- U16 = 2,
- I32 = 4,
- I64 = 8,
- F32 = 4,
- F64 = 8,
- TVMValue = 8,
- DLDataType = I32,
- DLDevice = I32 + I32,
-}
+export type FTVMFFIDataTypeToString = (dtype: Pointer, out: Pointer) => number;
/**
- * Argument Type code in TVM FFI.
+ * TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index);
*/
-export const enum ArgTypeCode {
- Int = 0,
- UInt = 1,
- Float = 2,
- TVMOpaqueHandle = 3,
- Null = 4,
- TVMDataType = 5,
- DLDevice = 6,
- TVMDLTensorHandle = 7,
- TVMObjectHandle = 8,
- TVMModuleHandle = 9,
- TVMPackedFuncHandle = 10,
- TVMStr = 11,
- TVMBytes = 12,
- TVMNDArrayHandle = 13,
- TVMObjectRValueRefArg = 14,
- TVMArgBool = 15,
-}
+export type FTVMFFIGetTypeInfo = (type_index: number) => Pointer;
diff --git a/web/src/environment.ts b/web/src/environment.ts
index 42a873f128..01e19a1c18 100644
--- a/web/src/environment.ts
+++ b/web/src/environment.ts
@@ -75,7 +75,7 @@ export class Environment implements LibraryProvider {
* We maintain a separate table so that we can have un-limited amount
* of functions that do not maps to the address space.
*/
- packedCFuncTable: Array<ctypes.FTVMWasmPackedCFunc | undefined> = [
+ packedCFuncTable: Array<ctypes.FTVMFFIWasmSafeCallType | undefined> = [
undefined,
];
/**
@@ -115,28 +115,27 @@ export class Environment implements LibraryProvider {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
"emscripten_notify_memory_growth": (index: number): void => {}
};
- const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = (
+ const wasmSafeCall: ctypes.FTVMFFIWasmSafeCallType = (
+ self: Pointer,
args: Pointer,
- typeCodes: Pointer,
- nargs: number,
- ret: Pointer,
- resourceHandle: Pointer
+ num_args: number,
+ result: Pointer
): number => {
- const cfunc = this.packedCFuncTable[resourceHandle];
+ const cfunc = this.packedCFuncTable[self];
assert(cfunc !== undefined);
- return cfunc(args, typeCodes, nargs, ret, resourceHandle);
+ return cfunc(self, args, num_args, result);
};
- const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = (
- resourceHandle: Pointer
+ const wasmFunctionDeleter: ctypes.FTVMFFIWasmFunctionDeleter = (
+ self: Pointer
): void => {
- this.packedCFuncTable[resourceHandle] = undefined;
- this.packedCFuncTableFreeId.push(resourceHandle);
+ this.packedCFuncTable[self] = undefined;
+ this.packedCFuncTableFreeId.push(self);
};
const newEnv = {
- TVMWasmPackedCFunc: wasmPackedCFunc,
- TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer,
+ "TVMFFIWasmSafeCall": wasmSafeCall,
+ "TVMFFIWasmFunctionDeleter": wasmFunctionDeleter,
"__console_log": (msg: string): void => {
this.logger(msg);
}
diff --git a/web/src/memory.ts b/web/src/memory.ts
index b0d4ff3bf1..850f3bd371 100644
--- a/web/src/memory.ts
+++ b/web/src/memory.ts
@@ -137,16 +137,6 @@ export class Memory {
result.set(this.viewU8.slice(ptr, ptr + numBytes));
return result;
}
- /**
- * Load TVMByteArray from ptr.
- *
- * @param ptr The address of the header.
- */
- loadTVMBytes(ptr: Pointer): Uint8Array {
- const data = this.loadPointer(ptr);
- const length = this.loadUSize(ptr + this.sizeofPtr());
- return this.loadRawBytes(data, length);
- }
/**
* Load null-terminated C-string from ptr.
* @param ptr The head address
@@ -178,7 +168,56 @@ export class Memory {
}
this.viewU8.set(bytes, ptr);
}
-
+ // the following functions are related to TVM FFI
+ /**
+ * Load the object type index from the object handle.
+ * @param objectHandle The handle of the object.
+ * @returns The object type index.
+ */
+ loadObjectTypeIndex(objectHandle: Pointer): number {
+ return this.loadI32(objectHandle);
+ }
+ /**
+ * Load the type key from the type info pointer.
+ * @param typeInfoPtr The pointer to the type info.
+ * @returns The type key.
+ */
+ loadTypeInfoTypeKey(typeInfoPtr: Pointer): string {
+ const typeKeyPtr = typeInfoPtr + 2 * SizeOf.I32;
+ return this.loadByteArrayAsString(typeKeyPtr);
+ }
+ /**
+ * Load bytearray as string from ptr.
+ * @param byteArrayPtr The head address of the bytearray.
+ */
+ loadByteArrayAsString(byteArrayPtr: Pointer): string {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const ptr = this.loadPointer(byteArrayPtr);
+ const length = this.loadUSize(byteArrayPtr + this.sizeofPtr());
+ // NOTE: the views are still valid for read.
+ const ret = [];
+ for (let i = 0; i < length; i++) {
+ ret.push(String.fromCharCode(this.viewU8[ptr + i]));
+ }
+ return ret.join("");
+ }
+ /**
+ * Load the payload of a TVMFFIByteArray as bytes.
+ * @param byteArrayPtr The head address of the bytearray
+ * (a data pointer followed by a size_t length).
+ * @returns A fresh copy of the payload bytes.
+ */
+ loadByteArrayAsBytes(byteArrayPtr: Pointer): Uint8Array {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const ptr = this.loadPointer(byteArrayPtr);
+ const length = this.loadUSize(byteArrayPtr + this.sizeofPtr());
+ // slice() already returns an independent copy; no second buffer is needed.
+ return this.viewU8.slice(ptr, ptr + length);
+ }
+ // private functions
/**
* Update memory view after the memory growth.
*/
@@ -365,6 +404,21 @@ export class CachedCallStack implements Disposable {
this.viewU8.set(bytes, offset);
}
+ /**
+ * Allocate a TVMFFIByteArray (header + payload) holding a string and
+ * return the offset of the header.
+ *
+ * The string is UTF-8 encoded and the header's size field records the
+ * number of encoded bytes. `data.length` would count UTF-16 code units,
+ * which differs from the byte count for any non-ASCII text.
+ * A trailing NUL is written after the payload but not counted in the size.
+ * @param data The string to allocate.
+ * @returns The offset of the byte array header.
+ */
+ allocByteArrayForString(data: string): PtrOffset {
+ const encoded: Uint8Array = new TextEncoder().encode(data);
+ // Payload plus one trailing NUL byte (excluded from the recorded size).
+ const dataUint8 = new Uint8Array(encoded.length + 1);
+ dataUint8.set(encoded);
+ // Note: size of size_t equals sizeof ptr.
+ const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2);
+ const dataOffset = this.allocRawBytes(dataUint8.length);
+ this.storeUSize(headerOffset + this.memory.sizeofPtr(), encoded.length);
+ this.storeRawBytes(dataOffset, dataUint8);
+ this.addressToSetTargetValue.push([headerOffset, dataOffset]);
+ return headerOffset;
+ }
/**
* Allocate then set C-String pointer to the offset.
* This function will call into allocBytes to allocate necessary data.
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
index 46848a6dec..1e3af6f643 100644
--- a/web/src/rpc_server.ts
+++ b/web/src/rpc_server.ts
@@ -17,7 +17,7 @@
* under the License.
*/
-import { SizeOf, ArgTypeCode } from "./ctypes";
+import { SizeOf, TypeIndex } from "./ctypes";
import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
import * as compact from "./compact";
@@ -228,21 +228,16 @@ export class RPCServer {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const ver = Uint8ArrayToString(reader.readByteArray());
const nargs = reader.readU32();
- const tcodes = [];
const args = [];
for (let i = 0; i < nargs; ++i) {
- tcodes.push(reader.readU32());
- }
-
- for (let i = 0; i < nargs; ++i) {
- const tcode = tcodes[i];
- if (tcode === ArgTypeCode.TVMStr) {
+ const typeIndex = reader.readU32(); // each arg: its FFI type index, then its payload
+ if (typeIndex === TypeIndex.kTVMFFIRawStr) {
const str = Uint8ArrayToString(reader.readByteArray());
args.push(str);
- } else if (tcode === ArgTypeCode.TVMBytes) {
+ } else if (typeIndex === TypeIndex.kTVMFFIByteArrayPtr) {
args.push(reader.readByteArray());
} else {
- throw new Error("cannot support type code " + tcode);
+ throw new Error("cannot support type index " + typeIndex);
}
}
this.onInitServer(args, header, body);
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 5c47c0e7a5..47902086f5 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -20,7 +20,7 @@
/**
* TVM JS Wasm Runtime library.
*/
-import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes";
+import { Pointer, PtrOffset, SizeOf, TypeIndex } from "./ctypes";
import { Disposable } from "./types";
import { Memory, CachedCallStack } from "./memory";
import { assert, StringToUint8Array, LinearCongruentialGenerator } from
"./support";
@@ -90,8 +90,8 @@ class FFILibrary implements Disposable {
checkCall(code: number): void {
if (code != 0) {
const msgPtr = (this.exports
- .TVMGetLastError as ctypes.FTVMGetLastError)();
- throw new Error("TVMError: " + this.memory.loadCString(msgPtr));
+ .TVMFFIWasmGetLastError as ctypes.FTVMFFIWasmGetLastError)();
+ throw new Error(this.memory.loadCString(msgPtr)); // message is used verbatim, no prefix added
}
}
@@ -153,6 +153,13 @@ class FFILibrary implements Disposable {
* Manages extra runtime context for the runtime.
*/
class RuntimeContext implements Disposable {
+ functionListGlobalNamesFunctor: PackedFunc;
+ moduleGetFunction: PackedFunc;
+ moduleImport: PackedFunc;
+ ndarrayEmpty: PackedFunc;
+ ndarrayCopyFromTo: PackedFunc;
+ ndarrayCopyFromJSBytes: PackedFunc;
+ ndarrayCopyToJSBytes: PackedFunc;
arrayGetItem: PackedFunc;
arrayGetSize: PackedFunc;
arrayMake: PackedFunc;
@@ -173,10 +180,21 @@ class RuntimeContext implements Disposable {
applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
concatEmbeddings: PackedFunc | undefined;
-
+ bool: PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
- constructor(getGlobalFunc: (name: string) => PackedFunc) {
+ constructor(
+ getGlobalFunc: (name: string) => PackedFunc
+ ) {
+ this.functionListGlobalNamesFunctor = getGlobalFunc(
+ "ffi.FunctionListGlobalNamesFunctor"
+ );
+ this.moduleGetFunction = getGlobalFunc("runtime.ModuleGetFunction");
+ this.moduleImport = getGlobalFunc("runtime.ModuleImport");
+ this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope");
+ this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo");
+ this.ndarrayCopyFromJSBytes =
getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes");
+ this.ndarrayCopyToJSBytes =
getGlobalFunc("tvmjs.runtime.NDArrayCopyToBytes");
this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem");
this.arrayGetSize = getGlobalFunc("runtime.ArraySize");
this.arrayMake = getGlobalFunc("runtime.Array");
@@ -189,18 +207,14 @@ class RuntimeContext implements Disposable {
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
this.paramModuleFromCache =
getGlobalFunc("vm.builtin.param_module_from_cache");
this.paramModuleFromCacheByName =
getGlobalFunc("vm.builtin.param_module_from_cache_by_name");
- this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
+ this.makeShapeTuple = getGlobalFunc("ffi.Shape");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits =
getGlobalFunc("vm.builtin.sample_top_p_from_logits");
this.sampleTopPFromProb =
getGlobalFunc("vm.builtin.sample_top_p_from_prob");
this.applyRepetitionPenalty =
getGlobalFunc("vm.builtin.apply_repetition_penalty");
this.applyPresenceAndFrequencyPenalty =
getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature =
getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
- try {
- this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings");
- } catch {
- // TODO: remove soon. Older artifacts do not have this, try-catch for
backward compatibility.
- }
+ this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); // NOTE(review): try/catch removed, so a missing function presumably now fails construction
}
dispose(): void {
@@ -306,35 +320,6 @@ export class Scalar {
}
}
-/**
- * Cell holds the PackedFunc object.
- */
-class PackedFuncCell implements Disposable {
- private handle: Pointer;
- private lib: FFILibrary;
-
- constructor(handle: Pointer, lib: FFILibrary) {
- this.handle = handle;
- this.lib = lib;
- }
-
- dispose(): void {
- if (this.handle != 0) {
- this.lib.checkCall(
- (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle)
- );
- this.handle = 0;
- }
- }
-
- getHandle(requireNotNull = true): Pointer {
- if (requireNotNull && this.handle === 0) {
- throw Error("PackedFunc has already been disposed");
- }
- return this.handle;
- }
-}
-
const DeviceEnumToStr: Record<number, string> = {
1: "cpu",
2: "cuda",
@@ -392,7 +377,7 @@ export class DLDevice {
toString(): string {
return (
- DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")"
+ DeviceEnumToStr[this.deviceType] + ":" + this.deviceId.toString() // e.g. "cpu:0"
);
}
}
@@ -444,12 +429,78 @@ export class DLDataType {
}
}
+/**
+ * Generic object base
+ */
+export class TVMObject implements Disposable {
+ protected handle: Pointer;
+ protected lib: FFILibrary;
+ protected ctx: RuntimeContext;
+
+ constructor(
+ handle: Pointer,
+ lib: FFILibrary,
+ ctx: RuntimeContext
+ ) {
+ this.handle = handle;
+ this.lib = lib;
+ this.ctx = ctx;
+ }
+
+ dispose(): void {
+ if (this.handle != 0) {
+ this.lib.checkCall(
+ (this.lib.exports.TVMFFIObjectFree as
ctypes.FTVMFFIObjectFree)(this.handle)
+ );
+ this.handle = 0;
+ }
+ }
+
+ /**
+ * Get handle of the object, check it is not null.
+ *
+ * @param requireNotNull require handle is not null.
+ * @returns The handle.
+ */
+ getHandle(requireNotNull = true): Pointer {
+ if (requireNotNull && this.handle === 0) {
+ throw Error("Object has already been disposed");
+ }
+ return this.handle;
+ }
+
+ /** get the type index of the object */
+ typeIndex(): number {
+ if (this.handle === 0) {
+ throw Error("The current Object has already been disposed");
+ }
+ return this.lib.memory.loadObjectTypeIndex(this.handle);
+ }
+
+ /** get the type key of the object */
+ typeKey(): string {
+ const type_index = this.typeIndex();
+ const typeInfoPtr = (this.lib.exports.TVMFFIGetTypeInfo as
ctypes.FTVMFFIGetTypeInfo)(
+ type_index
+ );
+ return this.lib.memory.loadTypeInfoTypeKey(typeInfoPtr);
+ }
+}
+
+/**
+ * Cell holds the PackedFunc object.
+ */
+class PackedFuncCell extends TVMObject {
+ constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) {
+ super(handle, lib, ctx);
+ }
+}
+
/**
* n-dimnesional array.
*/
-export class NDArray implements Disposable {
- /** Internal array handle. */
- private handle: Pointer;
+
+export class NDArray extends TVMObject {
/** Number of dimensions. */
ndim: number;
/** Data type of the array. */
@@ -463,16 +514,14 @@ export class NDArray implements Disposable {
private byteOffset: number;
private dltensor: Pointer;
private dataPtr: Pointer;
- private lib: FFILibrary;
- private ctx: RuntimeContext;
private dlDataType: DLDataType;
- constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx:
RuntimeContext) {
- this.handle = handle;
+ constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext, isView:
boolean) {
+ // if the array is a view, we need to create a new object with a null
handle
+ // so dispose won't trigger memory free (a view's getHandle() then throws, since it requires a non-null handle)
+ const objectHandle = isView ? 0 : handle;
+ super(objectHandle, lib, ctx);
this.isView = isView;
- this.lib = lib;
- this.ctx = ctx;
-
if (this.isView) {
this.dltensor = handle;
} else {
@@ -535,20 +584,6 @@ export class NDArray implements Disposable {
/*relative_byte_offset=*/ new Scalar(0, "int"),
);
}
-
- /**
- * Get handle of ndarray, check it is not null.
- *
- * @param requireNotNull require handle is not null.
- * @returns The handle.
- */
- getHandle(requireNotNull = true): Pointer {
- if (requireNotNull && this.handle === 0) {
- throw Error("NDArray has already been disposed");
- }
- return this.handle;
- }
-
/**
* Get dataPtr of NDarray
*
@@ -561,14 +596,6 @@ export class NDArray implements Disposable {
return this.dataPtr;
}
- dispose(): void {
- if (this.handle != 0 && !this.isView) {
- this.lib.checkCall(
- (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle)
- );
- this.handle = 0;
- }
- }
/**
* Copy data from another NDArray or javascript array.
* The number of elements must match.
@@ -581,13 +608,7 @@ export class NDArray implements Disposable {
Int32Array | Int8Array | Uint8Array | Uint8ClampedArray
): this {
if (data instanceof NDArray) {
- this.lib.checkCall(
- (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)(
- data.getHandle(),
- this.getHandle(),
- 0
- )
- );
+ this.ctx.ndarrayCopyFromTo(data, this); // copy direction: data -> this
return this;
} else {
const size = this.shape.reduce((a, b) => {
@@ -639,21 +660,7 @@ export class NDArray implements Disposable {
if (nbytes != data.length) {
throw new Error("Expect the data's length equals nbytes=" + nbytes);
}
-
- const stack = this.lib.getOrAllocCallStack();
-
- const tempOffset = stack.allocRawBytes(nbytes);
- const tempPtr = stack.ptrFromOffset(tempOffset);
- this.lib.memory.storeRawBytes(tempPtr, data);
- this.lib.checkCall(
- (this.lib.exports.TVMArrayCopyFromBytes as
ctypes.FTVMArrayCopyFromBytes)(
- this.getHandle(),
- tempPtr,
- nbytes
- )
- );
-
- this.lib.recycleCallStack(stack);
+ this.ctx.ndarrayCopyFromJSBytes(this, data);
return this;
}
/**
@@ -664,26 +671,7 @@ export class NDArray implements Disposable {
if (this.device.deviceType != DeviceStrToEnum.cpu) {
throw new Error("Can only sync copy CPU array, use
cpu_arr.copyfrom(gpu_arr) then sync instead.");
}
- const size = this.shape.reduce((a, b) => {
- return a * b;
- }, 1);
-
- const nbytes = this.dlDataType.numStorageBytes() * size;
- const stack = this.lib.getOrAllocCallStack();
-
- const tempOffset = stack.allocRawBytes(nbytes);
- const tempPtr = stack.ptrFromOffset(tempOffset);
- this.lib.checkCall(
- (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)(
- this.getHandle(),
- tempPtr,
- nbytes
- )
- );
- const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes);
-
- this.lib.recycleCallStack(stack);
- return ret;
+ return this.ctx.ndarrayCopyToJSBytes(this) as Uint8Array;
}
/**
@@ -709,52 +697,22 @@ export class NDArray implements Disposable {
}
private getDLTensorFromArrayHandle(handle: Pointer): Pointer {
- // Note: this depends on the NDArray C ABI.
- // keep this function in case of ABI change.
- return handle;
+ return handle + SizeOf.ObjectHeader; // the DLTensor starts right after the object header
}
}
+
/**
* Runtime Module.
*/
-export class Module implements Disposable {
- private handle: Pointer;
- private lib: FFILibrary;
- private makePackedFunc: (ptr: Pointer) => PackedFunc;
-
+export class Module extends TVMObject {
constructor(
handle: Pointer,
lib: FFILibrary,
- makePackedFunc: (ptr: Pointer) => PackedFunc
+ ctx: RuntimeContext,
) {
- this.handle = handle;
- this.lib = lib;
- this.makePackedFunc = makePackedFunc;
- }
-
- dispose(): void {
- if (this.handle != 0) {
- this.lib.checkCall(
- (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle)
- );
- this.handle = 0;
- }
- }
-
- /**
- * Get handle of module, check it is not null.
- *
- * @param requireNotNull require handle is not null.
- * @returns The handle.
- */
- getHandle(requireNotNull = true): Pointer {
- if (requireNotNull && this.handle === 0) {
- throw Error("Module has already been disposed");
- }
- return this.handle;
+ super(handle, lib, ctx);
}
-
/**
* Get a function in the module.
* @param name The name of the function.
@@ -762,33 +720,13 @@ export class Module implements Disposable {
* @returns The result function.
*/
getFunction(name: string, queryImports = true): PackedFunc {
- if (this.handle === 0) {
- throw Error("Module has already been disposed");
- }
- const stack = this.lib.getOrAllocCallStack();
- const nameOffset = stack.allocRawBytes(name.length + 1);
- stack.storeRawBytes(nameOffset, StringToUint8Array(name));
-
- const outOffset = stack.allocPtrArray(1);
- const outPtr = stack.ptrFromOffset(outOffset);
-
- stack.commitToWasmMemory(outOffset);
-
- this.lib.checkCall(
- (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
- this.getHandle(),
- stack.ptrFromOffset(nameOffset),
- queryImports ? 1 : 0,
- outPtr
- )
- );
- const handle = this.lib.memory.loadPointer(outPtr);
- this.lib.recycleCallStack(stack);
- if (handle === 0) {
- throw Error("Cannot find function " + name);
- }
- const ret = this.makePackedFunc(handle);
- return ret;
+ const func = this.ctx.moduleGetFunction(this, name, queryImports);
+ // If the lookup yields no function (presumably a None result; confirm against
+ // runtime.ModuleGetFunction), raise the same error the old code raised.
+ if (func === null || func === undefined) {
+ throw Error("Cannot find function " + name);
+ }
+ return func as PackedFunc;
}
/**
@@ -796,100 +728,16 @@ export class Module implements Disposable {
* @param mod The module to be imported.
*/
importModule(mod: Module): void {
- this.lib.checkCall(
- (this.lib.exports.TVMModImport as ctypes.FTVMModImport)(
- this.getHandle(),
- mod.getHandle()
- )
- );
+ this.ctx.moduleImport(this, mod);
}
}
-/**
- * Generic object base
- */
-export class TVMObject implements Disposable {
- private handle: Pointer;
- private lib: FFILibrary;
- protected ctx: RuntimeContext;
-
- constructor(
- handle: Pointer,
- lib: FFILibrary,
- ctx: RuntimeContext
- ) {
- this.handle = handle;
- this.lib = lib;
- this.ctx = ctx;
- }
-
- dispose(): void {
- if (this.handle != 0) {
- this.lib.checkCall(
- (this.lib.exports.TVMObjectFree as ctypes.FTVMObjectFree)(this.handle)
- );
- this.handle = 0;
- }
- }
-
- /**
- * Get handle of module, check it is not null.
- *
- * @param requireNotNull require handle is not null.
- * @returns The handle.
- */
- getHandle(requireNotNull = true): Pointer {
- if (requireNotNull && this.handle === 0) {
- throw Error("Module has already been disposed");
- }
- return this.handle;
- }
-
- /** get the type index of the object */
- typeIndex(): number {
- if (this.handle === 0) {
- throw Error("The current Object has already been disposed");
- }
- const stack = this.lib.getOrAllocCallStack();
- const outOffset = stack.allocPtrArray(1);
- const outPtr = stack.ptrFromOffset(outOffset);
-
- this.lib.checkCall(
- (this.lib.exports.TVMObjectGetTypeIndex as
ctypes.FTVMObjectGetTypeIndex)(
- this.getHandle(),
- outPtr
- )
- );
- const result = this.lib.memory.loadU32(outPtr);
- this.lib.recycleCallStack(stack);
- return result;
- }
-
- /** get the type key of the object */
- typeKey(): string {
- const type_index = this.typeIndex();
- const stack = this.lib.getOrAllocCallStack();
- const outOffset = stack.allocPtrArray(1);
- const outPtr = stack.ptrFromOffset(outOffset);
- this.lib.checkCall(
- (this.lib.exports.TVMObjectTypeIndex2Key as
ctypes.FTVMObjectTypeIndex2Key)(
- type_index,
- outPtr
- )
- );
- const result = this.lib.memory.loadCString(
- this.lib.memory.loadPointer(outPtr)
- );
- this.lib.recycleCallStack(stack);
- return result;
- }
-}
/** Objectconstructor */
type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx:
RuntimeContext) => TVMObject;
/** All possible object types. */
-type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc;
+type TVMObjectBase = TVMObject | PackedFunc; // NDArray and Module now extend TVMObject
/** Runtime array object. */
export class TVMArray extends TVMObject {
@@ -1212,38 +1060,16 @@ export class Instance implements Disposable {
* @returns The name list.
*/
listGlobalFuncNames(): Array<string> {
- const stack = this.lib.getOrAllocCallStack();
-
- const outSizeOffset = stack.allocPtrArray(2);
-
- const outSizePtr = stack.ptrFromOffset(outSizeOffset);
- const outArrayPtr = stack.ptrFromOffset(
- outSizeOffset + this.lib.sizeofPtr()
- );
-
- this.lib.checkCall(
- (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)(
- outSizePtr,
- outArrayPtr
- )
- );
-
- const size = this.memory.loadI32(outSizePtr);
- const array = this.memory.loadPointer(outArrayPtr);
- const names: Array<string> = [];
-
- for (let i = 0; i < size; ++i) {
- names.push(
- this.memory.loadCString(
- this.memory.loadPointer(array + this.lib.sizeofPtr() * i)
- )
- );
- }
-
- this.lib.recycleCallStack(stack);
- return names;
+ return this.withNewScope(() => {
+ const functor = this.ctx.functionListGlobalNamesFunctor();
+ const numNames = functor(new Scalar(-1, "int")) as number; // index -1 is used to query the count
+ const names = new Array<string>(numNames);
+ for (let i = 0; i < numNames; i++) {
+ names[i] = functor(new Scalar(i, "int")) as string;
+ }
+ return names;
+ });
}
-
/**
* Register function to be global function in tvm runtime.
* @param name The name of the function.
@@ -1262,12 +1088,10 @@ export class Instance implements Disposable {
const ioverride = override ? 1 : 0;
const stack = this.lib.getOrAllocCallStack();
- const nameOffset = stack.allocRawBytes(name.length + 1);
- stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+ const nameOffset = stack.allocByteArrayForString(name); // passed as a TVMFFIByteArray header, not a C string
stack.commitToWasmMemory();
-
this.lib.checkCall(
- (this.lib.exports.TVMFuncRegisterGlobal as
ctypes.FTVMFuncRegisterGlobal)(
+ (this.lib.exports.TVMFFIFunctionSetGlobal as
ctypes.FTVMFFIFunctionSetGlobal)(
stack.ptrFromOffset(nameOffset),
packedFunc._tvmPackedCell.getHandle(),
ioverride
@@ -1289,15 +1113,14 @@ export class Instance implements Disposable {
private getGlobalFuncInternal(name: string, autoAttachToScope = true):
PackedFunc {
const stack = this.lib.getOrAllocCallStack();
- const nameOffset = stack.allocRawBytes(name.length + 1);
- stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+ const nameOffset = stack.allocByteArrayForString(name);
const outOffset = stack.allocPtrArray(1);
const outPtr = stack.ptrFromOffset(outOffset);
stack.commitToWasmMemory(outOffset);
this.lib.checkCall(
- (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)(
+ (this.exports.TVMFFIFunctionGetGlobal as
ctypes.FTVMFFIFunctionGetGlobal)(
stack.ptrFromOffset(nameOffset),
outPtr
)
@@ -1335,7 +1158,7 @@ export class Instance implements Disposable {
private toPackedFuncInternal(func: Function, autoAttachToScope: boolean):
PackedFunc {
if (this.isPackedFunc(func)) return func as PackedFunc;
- const ret =
this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func));
+ const ret =
this.createPackedFuncFromSafeCallType(this.wrapJSFuncAsSafeCallType(func));
if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret);
return ret;
}
@@ -1603,52 +1426,6 @@ export class Instance implements Disposable {
}
}
- /**
- * Convert dtype to {@link DLDataType}
- *
- * @param dtype The input dtype string or DLDataType.
- * @returns The converted result.
- */
- toDLDataType(dtype: string | DLDataType): DLDataType {
- if (dtype instanceof DLDataType) return dtype;
- if (typeof dtype === "string") {
- let pattern = dtype;
- let code,
- bits = 32,
- lanes = 1;
- if (pattern.substring(0, 5) === "float") {
- pattern = pattern.substring(5, pattern.length);
- code = DLDataTypeCode.Float;
- } else if (pattern.substring(0, 3) === "int") {
- pattern = pattern.substring(3, pattern.length);
- code = DLDataTypeCode.Int;
- } else if (pattern.substring(0, 4) === "uint") {
- pattern = pattern.substring(4, pattern.length);
- code = DLDataTypeCode.UInt;
- } else if (pattern.substring(0, 6) === "handle") {
- pattern = pattern.substring(5, pattern.length);
- code = DLDataTypeCode.OpaqueHandle;
- bits = 64;
- } else {
- throw new Error("Unknown dtype " + dtype);
- }
-
- const arr = pattern.split("x");
- if (arr.length >= 1) {
- const parsed = parseInt(arr[0]);
- if (parsed + "" === arr[0]) {
- bits = parsed;
- }
- }
- if (arr.length >= 2) {
- lanes = parseInt(arr[1]);
- }
- return new DLDataType(code, bits, lanes);
- } else {
- throw new Error("Unknown dtype " + dtype);
- }
- }
-
/**
* Create a new {@link Scalar} that can be passed to a PackedFunc.
* @param value The number value.
@@ -1698,36 +1475,11 @@ export class Instance implements Disposable {
dtype: string | DLDataType = "float32",
dev: DLDevice = this.device("cpu", 0)
): NDArray {
- dtype = this.toDLDataType(dtype);
shape = typeof shape === "number" ? [shape] : shape;
-
- const stack = this.lib.getOrAllocCallStack();
- const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64);
- for (let i = 0; i < shape.length; ++i) {
- stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]);
- }
-
- const outOffset = stack.allocPtrArray(1);
- const outPtr = stack.ptrFromOffset(outOffset);
- stack.commitToWasmMemory(outOffset);
-
- this.lib.checkCall(
- (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)(
- stack.ptrFromOffset(shapeOffset),
- shape.length,
- dtype.code,
- dtype.bits,
- dtype.lanes,
- dev.deviceType,
- dev.deviceId,
- outPtr
- )
- );
- const ret = this.ctx.attachToCurrentScope(
- new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx)
- );
- this.lib.recycleCallStack(stack);
- return ret;
+ // setPackedArguments has no encoding for DLDataType objects, so pass the dtype
+ // in string form (assumes DLDataType.toString() yields e.g. "float32").
+ const dtypeStr = typeof dtype === "string" ? dtype : dtype.toString();
+ return this.ctx.ndarrayEmpty(this.makeShapeTuple(shape), dtypeStr, dev, null);
}
/**
@@ -1936,15 +1685,13 @@ export class Instance implements Disposable {
typeKey: string
): number {
const stack = this.lib.getOrAllocCallStack();
- const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1);
- stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey));
+ const typeKeyOffset = stack.allocByteArrayForString(typeKey);
const outOffset = stack.allocPtrArray(1);
const outPtr = stack.ptrFromOffset(outOffset);
stack.commitToWasmMemory(outOffset);
-
this.lib.checkCall(
- (this.lib.exports.TVMObjectTypeKey2Index as
ctypes.FTVMObjectTypeKey2Index)(
+ (this.lib.exports.TVMFFITypeKeyToIndex as ctypes.FTVMFFITypeKeyToIndex)(
stack.ptrFromOffset(typeKeyOffset),
outPtr
)
@@ -2153,6 +1900,10 @@ export class Instance implements Disposable {
(handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
return new TVMArray(handle, lib, ctx);
});
+ this.registerObjectConstructor("runtime.Module", // wrap "runtime.Module" objects as Module
+ (handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
+ return new Module(handle, lib, ctx);
+ });
}
/** Register global packed functions needed by the backend to the env. */
/** Register global packed functions needed by the backend to the env. */
@@ -2224,8 +1975,8 @@ export class Instance implements Disposable {
this.registerAsyncServerFunc("testing.asyncAddOne", addOne);
}
- private createPackedFuncFromCFunc(
- func: ctypes.FTVMWasmPackedCFunc
+ private createPackedFuncFromSafeCallType(
+ func: ctypes.FTVMFFIWasmSafeCallType
): PackedFunc {
let findex = this.env.packedCFuncTable.length;
if (this.env.packedCFuncTableFreeId.length != 0) {
@@ -2240,7 +1991,7 @@ export class Instance implements Disposable {
const outPtr = stack.ptrFromOffset(outOffset);
this.lib.checkCall(
(this.exports
- .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
+ .TVMFFIWasmFunctionCreate as ctypes.FTVMFFIWasmFunctionCreate)(
findex,
outPtr
)
@@ -2256,20 +2007,19 @@ export class Instance implements Disposable {
*
* @parma stack The call stack
* @param args The input arguments.
- * @param argsValue The offset of argsValue.
- * @param argsCode The offset of argsCode.
+ * @param packedArgs The offset of packedArgs.
*/
setPackedArguments(
stack: CachedCallStack,
args: Array<any>,
- argsValue: PtrOffset,
- argsCode: PtrOffset
+ packedArgs: PtrOffset,
): void {
for (let i = 0; i < args.length; ++i) {
let val = args[i];
const tp = typeof val;
- const valueOffset = argsValue + i * SizeOf.TVMValue;
- const codeOffset = argsCode + i * SizeOf.I32;
+ const argOffset = packedArgs + i * SizeOf.TVMFFIAny;
+ const argTypeIndexOffset = argOffset;
+ const argValueOffset = argOffset + SizeOf.I32 * 2;
// Convert string[] to a TVMArray of, hence treated as a TVMObject
if (val instanceof Array && val.every(e => typeof e === "string")) {
@@ -2278,97 +2028,100 @@ export class Instance implements Disposable {
val = this.makeTVMArray(tvmStringArray);
}
+ // clear off the extra padding values before ptr storage
+ stack.storeI32(argTypeIndexOffset + SizeOf.I32, 0);
+ stack.storeI32(argValueOffset + SizeOf.I32, 0);
if (val instanceof NDArray) {
if (!val.isView) {
- stack.storePtr(valueOffset, val.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINDArray);
+ stack.storePtr(argValueOffset, val.getHandle());
} else {
- stack.storePtr(valueOffset, val.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDLTensorPtr);
+ stack.storePtr(argValueOffset, val.getHandle()); // NOTE(review): views are constructed with a null object handle, so getHandle() throws here; confirm and pass the DLTensor pointer instead
}
} else if (val instanceof Scalar) {
if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
- stack.storeI64(valueOffset, val.value);
- stack.storeI32(codeOffset, ArgTypeCode.Int);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIInt);
+ stack.storeI64(argValueOffset, val.value);
} else if (val.dtype.startsWith("float")) {
- stack.storeF64(valueOffset, val.value);
- stack.storeI32(codeOffset, ArgTypeCode.Float);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat);
+ stack.storeF64(argValueOffset, val.value);
} else {
assert(val.dtype === "handle", "Expect handle");
- stack.storePtr(valueOffset, val.value);
- stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIOpaquePtr);
+ stack.storePtr(argValueOffset, val.value);
}
} else if (val instanceof DLDevice) {
- stack.storeI32(valueOffset, val.deviceType);
- stack.storeI32(valueOffset + SizeOf.I32, val.deviceType);
- stack.storeI32(codeOffset, ArgTypeCode.DLDevice);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDevice);
+ stack.storeI32(argValueOffset, val.deviceType);
+ stack.storeI32(argValueOffset + SizeOf.I32, val.deviceId);
+ } else if (tp === "boolean") {
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIBool);
+ stack.storeI64(argValueOffset, val ? 1 : 0);
} else if (tp === "number") {
- stack.storeF64(valueOffset, val);
- stack.storeI32(codeOffset, ArgTypeCode.Float);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat);
+ stack.storeF64(argValueOffset, val);
// eslint-disable-next-line no-prototype-builtins
} else if (tp === "function" && val.hasOwnProperty("_tvmPackedCell")) {
- stack.storePtr(valueOffset, val._tvmPackedCell.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
+ stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle());
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction);
} else if (val === null || val === undefined) {
- stack.storePtr(valueOffset, 0);
- stack.storeI32(codeOffset, ArgTypeCode.Null);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINone);
+ stack.storePtr(argValueOffset, 0);
} else if (tp === "string") {
- stack.allocThenSetArgString(valueOffset, val);
- stack.storeI32(codeOffset, ArgTypeCode.TVMStr);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIRawStr);
+ stack.allocThenSetArgString(argValueOffset, val);
} else if (val instanceof Uint8Array) {
- stack.allocThenSetArgBytes(valueOffset, val);
- stack.storeI32(codeOffset, ArgTypeCode.TVMBytes);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIByteArrayPtr);
+ stack.allocThenSetArgBytes(argValueOffset, val);
} else if (val instanceof Function) {
val = this.toPackedFuncInternal(val, false);
stack.tempArgs.push(val);
- stack.storePtr(valueOffset, val._tvmPackedCell.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction);
+ stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle());
} else if (val instanceof Module) {
- stack.storePtr(valueOffset, val.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle);
+ stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIModule);
+ stack.storePtr(argValueOffset, val.getHandle());
} else if (val instanceof TVMObject) {
- stack.storePtr(valueOffset, val.getHandle());
- stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle);
+ stack.storeI32(argTypeIndexOffset, val.typeIndex());
+ stack.storePtr(argValueOffset, val.getHandle());
} else {
- throw new Error("Unsupported argument type " + tp);
+ throw new Error("Unsupported argument type " + tp + " value=`" +
val.toString() + "`");
}
}
}
- private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc {
+ private wrapJSFuncAsSafeCallType(func: Function):
ctypes.FTVMFFIWasmSafeCallType {
const lib = this.lib;
return (
- argValues: Pointer,
- argCodes: Pointer,
- nargs: number,
- ret: Pointer,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
- _handle: Pointer
+ self: Pointer,
+ packedArgs: Pointer,
+ numArgs: number,
+ ret: Pointer
): number => {
const jsArgs = [];
// use scope to track js values.
this.ctx.beginScope();
- for (let i = 0; i < nargs; ++i) {
- const valuePtr = argValues + i * SizeOf.TVMValue;
- const codePtr = argCodes + i * SizeOf.I32;
- let tcode = lib.memory.loadI32(codePtr);
-
- if (
- tcode === ArgTypeCode.TVMObjectHandle ||
- tcode === ArgTypeCode.TVMObjectRValueRefArg ||
- tcode === ArgTypeCode.TVMPackedFuncHandle ||
- tcode === ArgTypeCode.TVMNDArrayHandle ||
- tcode === ArgTypeCode.TVMModuleHandle
- ) {
+ for (let i = 0; i < numArgs; ++i) {
+ const argPtr = packedArgs + i * SizeOf.TVMFFIAny;
+ const typeIndex = lib.memory.loadI32(argPtr);
+
+ if (typeIndex >= TypeIndex.kTVMFFIRawStr) {
+ // NOTE: the following code have limitations in asyncify mode.
+ // The reason is that the TVMFFIAnyViewToOwnedAny will simply
+ // get skipped during the rewinding process, causing memory failure
+ if (!this.asyncifyHandler.isNormalStackState()) {
+ throw Error("Cannot handle str/object argument callback in
asyncify mode");
+ }
lib.checkCall(
- (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)(
- valuePtr,
- codePtr
+ (lib.exports.TVMFFIAnyViewToOwnedAny as
ctypes.FTVMFFIAnyViewToOwnedAny)(
+ argPtr,
+ argPtr // same slot as input: the view is converted in place
)
);
}
- tcode = lib.memory.loadI32(codePtr);
- jsArgs.push(this.retValueToJS(valuePtr, tcode, true));
+ jsArgs.push(this.retValueToJS(argPtr, true));
}
let rv: any;
@@ -2378,12 +2131,21 @@ export class Instance implements Disposable {
// error handling
// store error via SetLastError
this.ctx.endScope();
- const errMsg = "JSCallbackError: " + error.message;
+ const errKind = "JSCallbackError";
+ // A JS callback may throw any value, not only an Error instance.
+ const errMsg = (error && error.message !== undefined) ? String(error.message) : String(error);
const stack = lib.getOrAllocCallStack();
+ const errKindOffset = stack.allocRawBytes(errKind.length + 1);
+ stack.storeRawBytes(errKindOffset, StringToUint8Array(errKind));
- const errMsgOffset = stack.allocRawBytes(errMsg.length + 1);
- stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg));
+ // Size the buffer from the encoded bytes rather than errMsg.length, which
+ // counts UTF-16 code units and can undercount non-ASCII messages.
+ const errMsgBytes = StringToUint8Array(errMsg);
+ const errMsgOffset = stack.allocRawBytes(errMsgBytes.length);
+ stack.storeRawBytes(errMsgOffset, errMsgBytes);
stack.commitToWasmMemory();
- (this.lib.exports.TVMAPISetLastError as ctypes.FTVMAPISetLastError)(
+ // Export keys omit the F prefix of the ctypes aliases (cf. TVMFFIObjectFree / FTVMFFIObjectFree).
+ (this.lib.exports.TVMFFIErrorSetRaisedByCStr as ctypes.FTVMFFIErrorSetRaisedByCStr)(
+ stack.ptrFromOffset(errKindOffset),
stack.ptrFromOffset(errMsgOffset)
);
this.lib.recycleCallStack(stack);
@@ -2395,18 +2152,14 @@ export class Instance implements Disposable {
this.ctx.endScope();
if (rv !== undefined && rv !== null) {
const stack = lib.getOrAllocCallStack();
- const valueOffset = stack.allocRawBytes(SizeOf.TVMValue);
- const codeOffset = stack.allocRawBytes(SizeOf.I32);
- this.setPackedArguments(stack, [rv], valueOffset, codeOffset);
- const valuePtr = stack.ptrFromOffset(valueOffset);
- const codePtr = stack.ptrFromOffset(codeOffset);
+ const argOffset = stack.allocRawBytes(SizeOf.TVMFFIAny);
+ this.setPackedArguments(stack, [rv], argOffset);
stack.commitToWasmMemory();
+ const argPtr = stack.ptrFromOffset(argOffset);
lib.checkCall(
- (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)(
- ret,
- valuePtr,
- codePtr,
- 1
+ (lib.exports.TVMFFIAnyViewToOwnedAny as
ctypes.FTVMFFIAnyViewToOwnedAny)(
+ argPtr,
+ ret // owned copy of the stack-held view is written to the ret slot
)
);
lib.recycleCallStack(stack);
@@ -2416,38 +2169,26 @@ export class Instance implements Disposable {
}
private makePackedFunc(handle: Pointer): PackedFunc {
- const cell = new PackedFuncCell(handle, this.lib);
-
+ const cell = new PackedFuncCell(handle, this.lib, this.ctx);
const packedFunc = (...args: any): any => {
const stack = this.lib.getOrAllocCallStack();
-
- const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length);
- const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length);
-
- this.setPackedArguments(stack, args, valueOffset, tcodeOffset);
-
- const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue);
- const rcodeOffset = stack.allocRawBytes(SizeOf.I32);
- const rvaluePtr = stack.ptrFromOffset(rvalueOffset);
- const rcodePtr = stack.ptrFromOffset(rcodeOffset);
-
- // pre-store the rcode to be null, in case caller unwind
- // and not have chance to reset this rcode.
- stack.storeI32(rcodeOffset, ArgTypeCode.Null);
+ const argsOffset = stack.allocRawBytes(SizeOf.TVMFFIAny * args.length);
+ this.setPackedArguments(stack, args, argsOffset);
+ const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny);
+ // pre-store the result type index as None, in case the callee unwinds
+ // and does not get the chance to write a result.
+ stack.storeI32(retOffset, TypeIndex.kTVMFFINone);
stack.commitToWasmMemory();
-
this.lib.checkCall(
- (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)(
+ (this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)(
cell.getHandle(),
- stack.ptrFromOffset(valueOffset),
- stack.ptrFromOffset(tcodeOffset),
+ stack.ptrFromOffset(argsOffset),
args.length,
- rvaluePtr,
- rcodePtr
+ stack.ptrFromOffset(retOffset)
)
);
- const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false);
+ const ret = this.retValueToJS(stack.ptrFromOffset(retOffset), false);
this.lib.recycleCallStack(stack);
return ret;
};
@@ -2463,78 +2203,92 @@ export class Instance implements Disposable {
/**
- * Creaye return value of the packed func. The value us auto-tracked for dispose.
- * @param rvaluePtr The location of rvalue
- * @param tcode The type code.
+ * Create the return value of the packed func. The value is auto-tracked for dispose.
+ * @param resultAnyPtr The location of the TVMFFIAny that holds the result.
* @param callbackArg Whether it is being used in callbackArg.
* @returns The JS value.
*/
- private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any {
- switch (tcode) {
- case ArgTypeCode.Int:
- case ArgTypeCode.UInt:
- case ArgTypeCode.TVMArgBool:
- return this.memory.loadI64(rvaluePtr);
- case ArgTypeCode.Float:
- return this.memory.loadF64(rvaluePtr);
- case ArgTypeCode.TVMOpaqueHandle: {
- return this.memory.loadPointer(rvaluePtr);
+ private retValueToJS(resultAnyPtr: Pointer, callbackArg: boolean): any {
+ const typeIndex = this.memory.loadI32(resultAnyPtr);
+ const valuePtr = resultAnyPtr + SizeOf.I32 * 2;
+ switch (typeIndex) {
+ case TypeIndex.kTVMFFINone: return undefined;
+ case TypeIndex.kTVMFFIBool:
+ return this.memory.loadI64(valuePtr) != 0;
+ case TypeIndex.kTVMFFIInt:
+ return this.memory.loadI64(valuePtr);
+ case TypeIndex.kTVMFFIFloat:
+ return this.memory.loadF64(valuePtr);
+ case TypeIndex.kTVMFFIOpaquePtr: {
+ return this.memory.loadPointer(valuePtr);
}
- case ArgTypeCode.TVMNDArrayHandle: {
+ case TypeIndex.kTVMFFINDArray: {
return this.ctx.attachToCurrentScope(
- new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx)
+ new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)
);
}
- case ArgTypeCode.TVMDLTensorHandle: {
+ case TypeIndex.kTVMFFIDLTensorPtr: {
assert(callbackArg);
// no need to attach as we are only looking at view
- return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, this.ctx);
+ return new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true);
}
- case ArgTypeCode.TVMPackedFuncHandle: {
+ case TypeIndex.kTVMFFIFunction: {
return this.ctx.attachToCurrentScope(
- this.makePackedFunc(this.memory.loadPointer(rvaluePtr))
+ this.makePackedFunc(this.memory.loadPointer(valuePtr))
);
}
- case ArgTypeCode.TVMModuleHandle: {
- return this.ctx.attachToCurrentScope(
- new Module(
- this.memory.loadPointer(rvaluePtr),
- this.lib,
- (ptr: Pointer) => {
- return this.ctx.attachToCurrentScope(this.makePackedFunc(ptr));
- }
- )
+ case TypeIndex.kTVMFFIDevice: {
+ const deviceType = this.memory.loadI32(valuePtr);
+ const deviceId = this.memory.loadI32(valuePtr + SizeOf.I32);
+ return this.device(deviceType, deviceId);
+ }
+ case TypeIndex.kTVMFFIDataType: {
+ // return the dtype as its string form to keep things simple; the String
+ // handle written by TVMFFIDataTypeToString overwrites the dtype slot in place.
+ this.lib.checkCall(
+ (this.lib.exports.TVMFFIDataTypeToString as ctypes.FTVMFFIDataTypeToString)(valuePtr, valuePtr)
+ );
+ const strObjPtr = this.memory.loadPointer(valuePtr);
+ const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader);
+ this.lib.checkCall(
+ (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr)
+ );
+ return result;
+ }
+ case TypeIndex.kTVMFFIStr: {
+ const strObjPtr = this.memory.loadPointer(valuePtr);
+ const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader);
+ this.lib.checkCall(
+ (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr)
);
+ return result;
}
- case ArgTypeCode.TVMObjectHandle: {
- const obj = new TVMObject(
- this.memory.loadPointer(rvaluePtr),
- this.lib,
- this.ctx
+ case TypeIndex.kTVMFFIBytes: {
+ const bytesObjPtr = this.memory.loadPointer(valuePtr);
+ const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader);
+ this.lib.checkCall(
+ (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr)
);
- const func = this.objFactory.get(obj.typeIndex())
- if (func != undefined) {
- return this.ctx.attachToCurrentScope(
- func(obj.getHandle(), this.lib, this.ctx)
+ return result;
+ }
+ default: {
+ if (typeIndex >= TypeIndex.kTVMFFIStaticObjectBegin) {
+ const obj = new TVMObject(
+ this.memory.loadPointer(valuePtr),
+ this.lib,
+ this.ctx
);
+ const func = this.objFactory.get(obj.typeIndex());
+ if (func != undefined) {
+ return this.ctx.attachToCurrentScope(
+ func(obj.getHandle(), this.lib, this.ctx)
+ );
+ } else {
+ return this.ctx.attachToCurrentScope(obj);
+ }
} else {
- return this.ctx.attachToCurrentScope(obj);
+ throw new Error("Unsupported return type code=" + typeIndex);
}
}
- case ArgTypeCode.Null: return undefined;
- case ArgTypeCode.DLDevice: {
- const deviceType = this.memory.loadI32(rvaluePtr);
- const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32);
- return this.device(deviceType, deviceId);
- }
- case ArgTypeCode.TVMStr: {
- const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr));
- return ret;
- }
- case ArgTypeCode.TVMBytes: {
- return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr));
- }
- default:
- throw new Error("Unsupported return type code=" + tcode);
}
}
}
diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_ndarray.js
index 8d369216d2..495d050701 100644
--- a/web/tests/node/test_ndarray.js
+++ b/web/tests/node/test_ndarray.js
@@ -38,7 +38,7 @@ function testArrayCopy(dtype, arrayType) {
let data = [1, 2, 3, 4, 5, 6];
let a = tvm.empty([2, 3], dtype).copyFrom(data);
- assert(a.device.toString() == "cpu(0)");
+ assert(a.device.toString() == "cpu:0");
assert(a.shape[0] == 2 && a.shape[1] == 3);
let ret = a.toArray();
diff --git a/web/tests/node/test_object.js b/web/tests/node/test_object.js
index 2423ef4ceb..3db3bd9c84 100644
--- a/web/tests/node/test_object.js
+++ b/web/tests/node/test_object.js
@@ -42,10 +42,5 @@ test("object", () => {
let t1 = b.get(1);
assert(t1.getHandle() == t.getHandle());
-
- let ret_string = tvm.getGlobalFunc("testing.ret_string");
- let s1 = ret_string("hello");
- assert(s1 == "hello");
- ret_string.dispose();
});
});
diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js
index e1d070f0e4..e2b6c7b7c9 100644
--- a/web/tests/node/test_packed_func.js
+++ b/web/tests/node/test_packed_func.js
@@ -37,7 +37,7 @@ let tvm = new tvmjs.Instance(
test("GetGlobal", () => {
tvm.beginScope();
let flist = tvm.listGlobalFuncNames();
- let faddOne = tvm.getGlobalFunc("testing.add_one");
+ let faddOne = tvm.getGlobalFunc("tvmjs.testing.add_one");
let fecho = tvm.getGlobalFunc("testing.echo");
assert(faddOne(tvm.scalar(1, "int")) == 2);
@@ -146,31 +146,6 @@ test("ExceptionPassing", () => {
tvm.endScope();
});
-
-test("AsyncifyFunc", async () => {
- if (!tvm.asyncifyEnabled()) {
- console.log("Skip asyncify tests as it is not enabled..");
- return;
- }
- tvm.beginScope();
- tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) {
- await new Promise(resolve => setTimeout(resolve, 10));
- return x;
- });
- let fecho = tvm.wrapAsyncifyPackedFunc(
- tvm.getGlobalFunc("async_sleep_echo")
- );
- let fcall = tvm.wrapAsyncifyPackedFunc(
- tvm.getGlobalFunc("testing.call")
- );
- assert((await fecho(1)) == 1);
- assert((await fecho(2)) == 2);
- assert((await fcall(fecho, 2) == 2));
- tvm.endScope();
- assert(fecho._tvmPackedCell.getHandle(false) == 0);
- assert(fcall._tvmPackedCell.getHandle(false) == 0);
-});
-
test("NDArrayCbArg", () => {
tvm.beginScope();
let use_count = tvm.getGlobalFunc("testing.object_use_count");
@@ -204,8 +179,32 @@ test("NDArrayCbArg", () => {
test("Logging", () => {
tvm.beginScope();
- const log_info = tvm.getGlobalFunc("testing.log_info_str");
+ const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str");
log_info("helow world")
log_info.dispose();
tvm.endScope();
});
+
+test("AsyncifyFunc", async () => {
+ if (!tvm.asyncifyEnabled()) {
+ console.log("Skip asyncify tests as it is not enabled..");
+ return;
+ }
+ tvm.beginScope();
+ tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) {
+ await new Promise(resolve => setTimeout(resolve, 10));
+ return x;
+ });
+ let fecho = tvm.wrapAsyncifyPackedFunc(
+ tvm.getGlobalFunc("async_sleep_echo")
+ );
+ let fcall = tvm.wrapAsyncifyPackedFunc(
+ tvm.getGlobalFunc("tvmjs.testing.call")
+ );
+ assert((await fecho(1)) == 1);
+ assert((await fecho(2)) == 2);
+ assert((await fcall(fecho, 2) == 2));
+ tvm.endScope();
+ assert(fecho._tvmPackedCell.getHandle(false) == 0);
+ assert(fcall._tvmPackedCell.getHandle(false) == 0);
+});
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py
index e831afd9d3..8925da00a4 100644
--- a/web/tests/python/webgpu_rpc_test.py
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -35,7 +35,6 @@ def test_rpc():
return
# generate the wasm library
target = tvm.target.Target("webgpu", host="llvm
-mtriple=wasm32-unknown-unknown-wasm")
- runtime = Runtime("cpp", {"system-lib": True})
n = te.var("n")
A = te.placeholder((n,), name="A")