This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 0e8cdf2ccfdc13c63141d5e8d093b486b213caff Author: tqchen <[email protected]> AuthorDate: Sun Apr 13 11:54:33 2025 -0400 fix through web --- web/emcc/tvmjs_support.cc | 65 ++++++++++++++++------------------------------- web/emcc/wasm_runtime.cc | 21 +++++++-------- web/src/runtime.ts | 4 +-- 3 files changed, 35 insertions(+), 55 deletions(-) diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index ee00e390a5..970a1d9d31 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -107,22 +107,12 @@ class AsyncLocalSession : public LocalSession { if (name == "runtime.RPCTimeEvaluator") { return get_time_eval_placeholder_.get(); } else if (auto* fp = tvm::runtime::Registry::Get(name)) { - // return raw handle because the remote need to explicitly manage it. - tvm::runtime::TVMRetValue ret; - ret = *fp; - TVMValue val; - int type_code; - ret.MoveToCHost(&val, &type_code); - return val.v_handle; + TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(tvm::ffi::Any(*fp)); + return val.v_obj; } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { - tvm::runtime::TVMRetValue ret; - ret = *fp; - TVMValue val; - int type_code; - ret.MoveToCHost(&val, &type_code); - auto* rptr = val.v_handle; - async_func_set_.insert(rptr); - return rptr; + TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(tvm::ffi::Any(*fp)); + async_func_set_.insert(val.v_obj); + return val.v_obj; } else { return nullptr; } @@ -140,8 +130,7 @@ class AsyncLocalSession : public LocalSession { } } - void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, - int num_args, FAsyncCallback callback) final { + void AsyncCallFunc(PackedFuncHandle func, ffi::PackedArgs args, FAsyncCallback callback) final { auto it = async_func_set_.find(func); if (it != async_func_set_.end()) { PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { @@ -159,36 +148,31 @@ class AsyncLocalSession : public LocalSession { } }); - TVMRetValue temp; - std::vector<TVMValue> values(arg_values, arg_values + num_args); - std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args); - values.emplace_back(TVMValue()); - type_codes.emplace_back(0); - - TVMArgsSetter setter(&values[0], &type_codes[0]); + std::vector<AnyView> packed_args(args.data(), args.data() + args.size()); // pass the callback as the last argument. - setter(num_args, packed_callback); - - auto* pf = static_cast<PackedFuncObj*>(func); - pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp); + packed_args.emplace_back(AnyView(packed_callback)); + auto* pf = static_cast<ffi::FunctionObj*>(func); + Any temp; + pf->CallPacked(packed_args.data(), packed_args.size(), &temp); } else if (func == get_time_eval_placeholder_.get()) { // special handle time evaluator. try { - TVMArgs args(arg_values, arg_type_codes, num_args); PackedFunc retfunc = this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9]); TVMRetValue rv; rv = retfunc; this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + const void* pf = encoded_args[0].as<ffi::FunctionObj>(); + ICHECK(pf != nullptr); // mark as async. - async_func_set_.insert(encoded_args.values[1].v_handle); + async_func_set_.insert(const_cast<void*>(pf)); callback(RPCCode::kReturn, encoded_args); }); } catch (const std::runtime_error& e) { this->SendException(callback, e.what()); } } else { - LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback); + LocalSession::AsyncCallFunc(func, args, callback); } } @@ -230,10 +214,9 @@ class AsyncLocalSession : public LocalSession { void AsyncStreamWait(Device dev, TVMStreamHandle stream, FAsyncCallback on_complete) final { if (dev.device_type == kDLCPU) { - TVMValue value; - int32_t tcode = kTVMNullptr; - value.v_handle = nullptr; - on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + AnyView packed_args[1]; + packed_args[0] = nullptr; + on_complete(RPCCode::kReturn, ffi::PackedArgs(packed_args, 1)); } else { CHECK(dev.device_type == static_cast<DLDeviceType>(kDLWebGPU)); if (async_wait_ == nullptr) { @@ -242,8 +225,7 @@ class AsyncLocalSession : public LocalSession { CHECK(async_wait_ != nullptr); PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) { int code = args[0]; - on_complete(static_cast<RPCCode>(code), - TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1)); + on_complete(static_cast<RPCCode>(code), args.Slice(1)); }); (*async_wait_)(packed_callback); } @@ -288,14 +270,11 @@ class AsyncLocalSession : public LocalSession { cooldown_interval_ms, repeats_to_cooldown](TVMArgs args, TVMRetValue* rv) { // the function is a async function. PackedFunc on_complete = args[args.size() - 1]; - // keep argument alive in finvoke so that they - // can be used throughout the async benchmark - std::vector<TVMValue> values(args.values, args.values + args.size() - 1); - std::vector<int> type_codes(args.type_codes, args.type_codes + args.size() - 1); - auto finvoke = [pf, values, type_codes](int n) { + std::vector<AnyView> packed_args(args.data(), args.data() + args.size() - 1); + auto finvoke = [pf, packed_args](int n) { TVMRetValue temp; - TVMArgs invoke_args(values.data(), type_codes.data(), values.size()); + TVMArgs invoke_args(packed_args.data(), packed_args.size()); for (int i = 0; i < n; ++i) { pf.CallPacked(invoke_args, &temp); } diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 67255f254b..bbc427e60f 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -26,7 +26,7 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 - +#define TVM_FFI_USE_LIBBACKTRACE 0 #define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h> #include <tvm/runtime/c_runtime_api.h> @@ -66,6 +66,10 @@ #include "src/runtime/relax_vm/rnn_state.cc" #include "src/runtime/relax_vm/vm.cc" +#include "ffi/src/ffi/object.cc" +#include "ffi/src/ffi/function.cc" +#include "ffi/src/ffi/traceback.cc" + // --- Implementations of backend and wasm runtime API. --- int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { @@ -105,7 +109,7 @@ TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) TVM_REGISTER_GLOBAL("testing.call").set_body([](TVMArgs args, TVMRetValue* ret) { (args[0].operator PackedFunc()) - .CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1), ret); + .CallPacked(args.Slice(1), ret); }); TVM_REGISTER_GLOBAL("testing.ret_string").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -160,19 +164,17 @@ TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStor // Concatenate n TVMArrays TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector<ObjectRef> data; + std::vector<Any> data; for (int i = 0; i < args.size(); ++i) { // Get i-th TVMArray - ICHECK_EQ(args[i].type_code(), kTVMObjectHandle); - Object* ptr = static_cast<Object*>(args[i].value().v_handle); - ICHECK(ptr->IsInstance<ArrayNode>()); - auto* arr_i = static_cast<const ArrayNode*>(ptr); + auto* arr_i = args[i].as<ArrayNode>(); + ICHECK(arr_i != nullptr); for (size_t j = 0; j < arr_i->size(); ++j) { // Push back each j-th element of the i-th array data.push_back(arr_i->at(j)); } } - *ret = Array<ObjectRef>(data); + *ret = Array<Any>(data); }); NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) { @@ -213,8 +215,7 @@ NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) { TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings").set_body([](TVMArgs args, TVMRetValue* ret) { std::vector<NDArray> embeddings; for (int i = 0; i < args.size(); ++i) { - ICHECK_EQ(args[i].type_code(), kTVMNDArrayHandle); - embeddings.push_back(args[i]); + embeddings.push_back(args[i].operator NDArray()); } NDArray result = ConcatEmbeddings(std::move(embeddings)); *ret = result; diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 8546cab773..d14ae663b2 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -2183,11 +2183,11 @@ export class Instance implements Disposable { /** Register all object factory */ private registerObjectFactoryFuncs(): void { - this.registerObjectConstructor("Array", + this.registerObjectConstructor("object.Array", (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); }); - this.registerObjectConstructor("runtime.String", + this.registerObjectConstructor("object.String", (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMString(handle, lib, ctx); });
