This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 21fa6e5a9b32daca0b6a4d9cdb62b6033d1f63f4 Author: tqchen <[email protected]> AuthorDate: Mon Apr 28 16:53:55 2025 -0400 [FFI] More global function access utils --- ffi/include/tvm/ffi/container/map.h | 15 +++++++-------- ffi/include/tvm/ffi/function.h | 37 ++++++++++++++++++++++++++++++++++++- ffi/tests/cpp/test_map.cc | 4 ++-- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 31b630f9f8..116c9930f1 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -215,8 +215,7 @@ class MapObj : public Object { }; /*! \brief A specialization of small-sized hash map */ -class SmallMapObj : public MapObj, - public details::InplaceArrayBase<SmallMapObj, MapObj::KVType> { +class SmallMapObj : public MapObj, public details::InplaceArrayBase<SmallMapObj, MapObj::KVType> { private: static constexpr uint64_t kInitSize = 2; static constexpr uint64_t kMaxSize = 4; @@ -1005,10 +1004,10 @@ class DenseMapObj : public MapObj { #define TVM_DISPATCH_MAP(base, var, body) \ { \ - using TSmall = SmallMapObj*; \ - using TDense = DenseMapObj*; \ + using TSmall = SmallMapObj*; \ + using TDense = DenseMapObj*; \ uint64_t slots = base->slots_; \ - if (slots <= SmallMapObj::kMaxSize) { \ + if (slots <= SmallMapObj::kMaxSize) { \ TSmall var = static_cast<TSmall>(base); \ body; \ } else { \ @@ -1019,10 +1018,10 @@ class DenseMapObj : public MapObj { #define TVM_DISPATCH_MAP_CONST(base, var, body) \ { \ - using TSmall = const SmallMapObj*; \ - using TDense = const DenseMapObj*; \ + using TSmall = const SmallMapObj*; \ + using TDense = const DenseMapObj*; \ uint64_t slots = base->slots_; \ - if (slots <= SmallMapObj::kMaxSize) { \ + if (slots <= SmallMapObj::kMaxSize) { \ TSmall var = static_cast<TSmall>(base); \ body; \ } else { \ diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 6833e8b52b..d53bb01934 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -385,6 +385,13 @@ class Function : public ObjectRef { return std::nullopt; } } + + static std::optional<Function> GetGlobal(const std::string& name) { + return GetGlobal(name.c_str()); + } + + static std::optional<Function> GetGlobal(const String& name) { return GetGlobal(name.c_str()); } + /*! * \brief Get global function by name and throw an error if it is not found. * \param name The name of the function @@ -398,6 +405,13 @@ class Function : public ObjectRef { } return res.value(); } + + static Function GetGlobalRequired(const std::string& name) { + return GetGlobalRequired(name.c_str()); + } + + static Function GetGlobalRequired(const String& name) { return GetGlobalRequired(name.c_str()); } + /*! * \brief Set global function by name * \param name The name of the function @@ -408,7 +422,28 @@ class Function : public ObjectRef { TVM_FFI_CHECK_SAFE_CALL( TVMFFIFuncSetGlobal(name, details::ObjectUnsafe::GetHeader(func.get()), override)); } - + /*! + * \brief List all global names + * \return A vector of all global names + * \note This function do not depend on Array so core do not have container dep. + */ + static std::vector<String> ListGlobalNames() { + Function fname_functor = GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")(); + std::vector<String> names; + int len = fname_functor(-1); + for (int i = 0; i < len; ++i) { + names.push_back(fname_functor(i)); + } + return names; + } + /** + * \brief Remove a global function by name + * \param name The name of the function + */ + static void RemoveGlobal(const String& name) { + static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); + fremove(name); + } /*! * \brief Constructing a packed function from a normal function. * diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc index 282e4c8319..a4617a221e 100644 --- a/ffi/tests/cpp/test_map.cc +++ b/ffi/tests/cpp/test_map.cc @@ -244,8 +244,8 @@ TEST(Map, AnyConvertCheck) { } TEST(Map, PackedFuncGetItem) { - Function f = Function::FromUnpacked( - [](const MapObj* n, const Any& k) -> Any { return n->at(k); }, "map_get_item"); + Function f = Function::FromUnpacked([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, + "map_get_item"); Map<String, int64_t> map{{"x", 1}, {"y", 2}}; Any k("x"); Any v = f(map, k);
