This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit fac833560235e40b33b4b350b31fe8abebeec6aa Author: tqchen <[email protected]> AuthorDate: Sat May 3 08:57:23 2025 -0400 Fix up remaining items --- ffi/include/tvm/ffi/c_api.h | 4 ++-- ffi/src/ffi/function.cc | 4 ++-- src/relax/backend/contrib/cutlass/codegen.cc | 11 +++++++---- src/relax/backend/contrib/dnnl/codegen.cc | 2 +- src/runtime/hexagon/hexagon_common.cc | 2 +- src/runtime/hexagon/hexagon_device_api.cc | 12 ++++++------ web/emcc/tvmjs_support.cc | 8 +++++--- web/emcc/wasm_runtime.cc | 6 +++--- web/emcc/webgpu_runtime.cc | 2 +- 9 files changed, 28 insertions(+), 23 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 6569050b9e..c2bc31f0a0 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -399,8 +399,8 @@ TVM_FFI_DLL int TVMFFITypeKeyToIndex(const char* type_key, int32_t* out_tindex); * \param out The output of the function. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self), TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self), TVMFFIObjectHandle* out); /*! * \brief Convert a AnyView to an owned Any. diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index 483c022566..5b009769ac 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -226,8 +226,8 @@ class EnvCAPIRegistry { } // namespace ffi } // namespace tvm -int TVMFFIFuncCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), - TVMFFIObjectHandle* out) { +int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), + TVMFFIObjectHandle* out) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::Function func = tvm::ffi::Function::FromExternC(self, safe_call, deleter); *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 4aa5000bbb..115efa6adb 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -69,7 +69,9 @@ runtime::Module Finalize(const std::string& code, const Array<String>& func_name const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.CSourceModuleCreate"); VLOG(1) << "Generated CUTLASS code:" << std::endl << code; - return pf(default_headers.str() + code, "cu", func_names, /*const_vars=*/Array<String>()); + return pf(default_headers.str() + code, "cu", func_names, + /*const_vars=*/Array<String>()) + .cast<runtime::Module>(); } class CodegenResultNode : public Object { @@ -129,7 +131,8 @@ GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& const auto instantiate_template_func = tvm::ffi::Function::GetGlobalRequired("contrib.cutlass.instantiate_template"); - CodegenResult codegen_res = instantiate_template_func(func_name, attrs, func_args); + CodegenResult codegen_res = + instantiate_template_func(func_name, attrs, func_args).cast<CodegenResult>(); ret.decl = codegen_res->code; ret.headers = codegen_res->headers; @@ -371,13 +374,13 @@ Array<runtime::Module> CUTLASSCompiler(Array<Function> functions, Map<String, ff << "The packed function contrib.cutlass.tune_relax_function not found, " "please import tvm.contrib.cutlass.build"; - Array<Function> annotated_functions = (*tune_func)(functions, options); + auto annotated_functions = (*tune_func)(functions, options).cast<Array<Function>>(); auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); const auto pf = tvm::ffi::Function::GetGlobal("contrib.cutlass.compile"); ICHECK(pf.has_value()) << "The packed function contrib.cutlass.compile not found, please import " "tvm.contrib.cutlass.build"; - runtime::Module cutlass_mod = (*pf)(source_mod, options); + runtime::Module cutlass_mod = (*pf)(source_mod, options).cast<runtime::Module>(); return {cutlass_mod}; } diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc index d52823498c..ba219be4c6 100644 --- a/src/relax/backend/contrib/dnnl/codegen.cc +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -91,7 +91,7 @@ Array<runtime::Module> DNNLCompiler(Array<Function> functions, Map<String, ffi:: auto constant_names = serializer.GetConstantNames(); const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.DNNLJSONRuntimeCreate"); auto func_name = GetExtSymbol(func); - compiled_functions.push_back(pf(func_name, graph_json, constant_names)); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast<runtime::Module>()); } return compiled_functions; diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 60a4ab3b4c..69be065210 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -91,7 +91,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") .set_body_packed([](TVMArgs args, TVMRetValue* rv) { - ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]); + ObjectPtr<Library> n = CreateDSOLibraryObject(args[0].cast<String>()); *rv = CreateModuleFromLibrary(n); }); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index e17779c58f..1ebe83b5a4 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -211,10 +211,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") .set_body_packed([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast<int>(args[0]); + uint32_t queue_id = args[0].cast<uint32_t>(); void* dst = args[1].cast<void*>(); void* src = args[2].cast<void*>(); - uint32_t size = static_cast<int>(args[3]); + uint32_t size = args[3].cast<uint32_t>(); ICHECK(size > 0); bool bypass_cache = args[4].cast<bool>(); @@ -228,7 +228,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") .set_body_packed([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast<int>(args[0]); + uint32_t queue_id = args[0].cast<uint32_t>(); int inflight = args[1].cast<int>(); ICHECK(inflight >= 0); HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight); @@ -237,14 +237,14 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") .set_body_packed([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast<int>(args[0]); + uint32_t queue_id = args[0].cast<uint32_t>(); HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); *rv = static_cast<int32_t>(0); }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") .set_body_packed([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast<int>(args[0]); + uint32_t queue_id = args[0].cast<uint32_t>(); HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); *rv = static_cast<int32_t>(0); }); @@ -259,7 +259,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") CHECK(scope.find("global.vtcm") != std::string::npos); int64_t ndim = args[5].cast<int64_t>(); CHECK((ndim == 1 || ndim == 2) && "Hexagon Device API supports only 1d and 2d allocations"); - int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6])); + int64_t* shape = static_cast<int64_t*>(args[6].cast<void*>()); Device dev; dev.device_type = static_cast<DLDeviceType>(device_type); diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 2c4ba7da08..fd6265f6c0 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -157,8 +157,10 @@ class AsyncLocalSession : public LocalSession { } else if (func == get_time_eval_placeholder_.get()) { // special handle time evaluator. try { - PackedFunc retfunc = this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], - args[5], args[6], args[7], args[8], args[9]); + PackedFunc retfunc = this->GetTimeEvaluator( + args[0].cast<ffi::Optional<Module>>(), args[1].cast<std::string>(), args[2].cast<int>(), + args[3].cast<int>(), args[4].cast<int>(), args[5].cast<int>(), args[6].cast<int>(), + args[7].cast<int>(), args[8].cast<int>(), args[9].cast<int>()); TVMRetValue rv; rv = retfunc; this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { @@ -269,7 +271,7 @@ class AsyncLocalSession : public LocalSession { auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown](TVMArgs args, TVMRetValue* rv) { // the function is a async function. - PackedFunc on_complete = args[args.size() - 1]; + PackedFunc on_complete = args[args.size() - 1].cast<PackedFunc>(); std::vector<AnyView> packed_args(args.data(), args.data() + args.size() - 1); auto finvoke = [pf, packed_args](int n) { diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 0873101bbb..7856153f97 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -107,7 +107,7 @@ TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](TVMArgs args, TVMRetValue }); TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - (args[0].operator PackedFunc()).CallPacked(args.Slice(1), ret); + (args[0].cast<PackedFunc>()).CallPacked(args.Slice(1), ret); }); TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](TVMArgs args, TVMRetValue* ret) { @@ -125,7 +125,7 @@ TVM_REGISTER_GLOBAL("testing.log_fatal_str").set_body_packed([](TVMArgs args, TV TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pf = args[0]; + PackedFunc pf = args[0].cast<PackedFunc>(); *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); }); }); @@ -215,7 +215,7 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") .set_body_packed([](TVMArgs args, TVMRetValue* ret) { std::vector<NDArray> embeddings; for (int i = 0; i < args.size(); ++i) { - embeddings.push_back(args[i].operator NDArray()); + embeddings.push_back(args[i].cast<NDArray>()); } NDArray result = ConcatEmbeddings(std::move(embeddings)); *ret = result; diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 5b064cfc16..1aafc272c3 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -184,7 +184,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { } else if (name == "webgpu.update_prebuild") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { auto name = args[0].cast<std::string>(); - PackedFunc func = args[1]; + PackedFunc func = args[1].cast<PackedFunc>(); prebuild_[name] = func; }); }
