This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8fcd924  feat: Introduce `tvm.registry.get_registered_type_keys()` (#249)
8fcd924 is described below

commit 8fcd9245186df3d6570e641dfc1c84239a9f9a40
Author: Junru Shao <[email protected]>
AuthorDate: Sun Nov 9 15:03:01 2025 -0800

    feat: Introduce `tvm.registry.get_registered_type_keys()` (#249)
    
    This API is going to be useful for library stub generation, where we
    need to query the objects registered on the C++ side.
---
 python/tvm_ffi/_ffi_api.py  |  1 +
 python/tvm_ffi/registry.py  | 17 +++++++++++++++--
 src/ffi/object.cc           | 21 +++++++++++++++++++++
 tests/python/test_object.py | 30 +++++++++++++++++++++++++++++-
 4 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 5a957d6..316a6fb 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     def FunctionRemoveGlobal(_0: str, /) -> bool: ...
     def GetFirstStructuralMismatch(_0: Any, _1: Any, _2: bool, _3: bool, /) -> 
tuple[AccessPath, AccessPath] | None: ...
     def GetGlobalFuncMetadata(_0: str, /) -> str: ...
+    def GetRegisteredTypeKeys() -> Sequence[str]: ...
     def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
     def Map(*args: Any) -> Any: ...
     def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index ce1d15f..74f56cb 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import json
 import sys
-from typing import Any, Callable, Literal, TypeVar, overload
+from typing import Any, Callable, Literal, Sequence, TypeVar, overload
 
 from . import core
 from .core import TypeInfo
@@ -268,7 +268,7 @@ def get_global_func_metadata(name: str) -> dict[str, Any]:
         Register a Python callable as a global FFI function.
 
     """
-    return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name))
+    return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name) or 
"{}")
 
 
 def init_ffi_api(namespace: str, target_module_name: str | None = None) -> 
None:
@@ -346,9 +346,22 @@ def __init__invalid(self: Any, *args: Any, **kwargs: Any) 
-> None:
     raise RuntimeError("The __init__ method of this class is not implemented.")
 
 
+def get_registered_type_keys() -> Sequence[str]:
+    """Get the list of valid type keys registered to TVM-FFI.
+
+    Returns
+    -------
+    type_keys
+        List of valid type keys.
+
+    """
+    return get_global_func("ffi.GetRegisteredTypeKeys")()
+
+
 __all__ = [
     "get_global_func",
     "get_global_func_metadata",
+    "get_registered_type_keys",
     "init_ffi_api",
     "list_global_func_names",
     "register_global_func",
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 1671ba8..e8a232d 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -21,6 +21,7 @@
  * \brief Registry to record dynamic types
  */
 #include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
@@ -193,6 +194,16 @@ class TypeTable {
     return entry;
   }
 
+  Array<String> GetRegisteredTypeKeys() const {
+    Array<String> ret;
+    for (const auto& entry : type_table_) {
+      if (entry) {
+        ret.push_back(entry->type_key_data);
+      }
+    }
+    return ret;
+  }
+
   void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {
     Entry* entry = GetTypeEntry(type_index);
     TVMFFIFieldInfo field_data = *info;
@@ -537,3 +548,13 @@ int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, 
TVMFFIAny* out) {
   
tvm::ffi::TypeTraits<tvm::ffi::Bytes>::MoveToAny(tvm::ffi::Bytes(input->data, 
input->size), out);
   TVM_FFI_SAFE_CALL_END();
 }
+
+namespace {
+TVM_FFI_STATIC_INIT_BLOCK() {
+  using namespace tvm::ffi;
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_method("ffi.GetRegisteredTypeKeys", []() -> 
Array<String> {
+    return tvm::ffi::TypeTable::Global()->GetRegisteredTypeKeys();
+  });
+}
+}  // namespace
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index c39b0be..e49d3ec 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import sys
-from typing import Any
+from typing import Any, Sequence
 
 import pytest
 import tvm_ffi
@@ -228,3 +228,31 @@ def test_type_info_attachment(test_cls: type, type_key: 
str, parent_cls: type |
         assert parent_type_info.type_cls is parent_cls, (
             f"Expected parent type {parent_cls}, but got 
{parent_type_info.type_cls}"
         )
+
+
+def test_get_registered_type_keys() -> None:
+    keys = tvm_ffi.registry.get_registered_type_keys()
+    assert isinstance(keys, Sequence)
+    assert all(isinstance(k, str) for k in keys)
+    keys = set(keys)
+    assert "ffi.Object" in keys
+    assert "ffi.String" in keys
+    for ty in [
+        "None",
+        "int",
+        "bool",
+        "float",
+        "void*",
+        "DataType",
+        "Device",
+        "DLTensor*",
+        "const char*",
+        "TVMFFIByteArray*",
+        "ObjectRValueRef",
+    ]:
+        assert ty in keys, f"Expected to find `{ty}` in registered type keys, 
but it was not found."
+        keys.remove(ty)
+    for ty in keys:
+        assert ty.startswith("ffi.") or ty.startswith("testing."), (
+            f"Expected type key `{ty}` to start with `ffi.` or `testing.`"
+        )

Reply via email to