This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 785e8ca  feat: Add typestubs for Cython generated code (#24)
785e8ca is described below

commit 785e8ca1366b718cb95ae5521fbd39a06a36b778
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 18 19:52:55 2025 -0700

    feat: Add typestubs for Cython generated code (#24)
    
    After this PR, if Pylance still reports missing imports `core`, you may
    set:
    
    ```
    "python.analysis.extraPaths": ["./build/"]
    ```
    
    in base directory.
---
 python/tvm_ffi/_tensor.py      |  18 +++-
 python/tvm_ffi/core.pyi        | 227 +++++++++++++++++++++++++++++++++++++++++
 tests/lint/check_asf_header.py |   1 +
 3 files changed, 241 insertions(+), 5 deletions(-)

diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 903a69d..8d06bd2 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -21,12 +21,20 @@
 from numbers import Integral
 from typing import Any, Optional, Union
 
-from . import _ffi_api, core, registry
-from .core import Device, DLDeviceType, Tensor, from_dlpack
+from . import _ffi_api, registry
+from .core import (
+    _CLASS_DEVICE,
+    Device,
+    DLDeviceType,
+    PyNativeObject,
+    Tensor,
+    _shape_obj_get_py_tuple,
+    from_dlpack,
+)
 
 
 @registry.register_object("ffi.Shape")
-class Shape(tuple, core.PyNativeObject):
+class Shape(tuple, PyNativeObject):
     """Shape tuple that represents `ffi::Shape` returned by a ffi call.
 
     Note:
@@ -46,7 +54,7 @@ class Shape(tuple, core.PyNativeObject):
     # pylint: disable=no-self-argument
     def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
         """Construct from a given tvm object."""
-        content = core._shape_obj_get_py_tuple(obj)
+        content = _shape_obj_get_py_tuple(obj)
         val = tuple.__new__(cls, content)
         val.__tvm_ffi_object__ = obj
         return val
@@ -78,7 +86,7 @@ def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = No
       assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)
 
     """
-    return core._CLASS_DEVICE(device_type, index)
+    return _CLASS_DEVICE(device_type, index)
 
 
 __all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"]
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
new file mode 100644
index 0000000..d57d020
--- /dev/null
+++ b/python/tvm_ffi/core.pyi
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Typestubs for Cython."""
+
+from __future__ import annotations
+
+import types
+from enum import IntEnum
+from typing import Any, Callable
+
+# Public module-level variables referenced by Python code
+ERROR_NAME_TO_TYPE: dict[str, type]
+ERROR_TYPE_TO_NAME: dict[type, str]
+
+_WITH_APPEND_TRACEBACK: Callable[[BaseException, str], BaseException] | None
+_TRACEBACK_TO_STR: Callable[[types.TracebackType | None], str] | None
+
+# DLPack protocol version (defined in tensor.pxi)
+__dlpack_version__: tuple[int, int]
+
+class Object:
+    """Base class of all TVM FFI objects."""
+
+    def __ctypes_handle__(self) -> Any: ...
+    def __chandle__(self) -> int: ...
+    def __reduce__(self) -> Any: ...
+    def __getstate__(self) -> dict[str, Any]: ...
+    def __setstate__(self, state: dict[str, Any]) -> None: ...
+    def __repr__(self) -> str: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __hash__(self) -> int: ...
+    def __init_handle_by_constructor__(self, fconstructor: Function, *args: Any) -> None: ...
+    def same_as(self, other: Any) -> bool: ...
+    def _move(self) -> ObjectRValueRef: ...
+    def __move_handle_from__(self, other: Object) -> None: ...
+
+class ObjectConvertible:
+    """Base class for all classes that can be converted to Object."""
+
+    def asobject(self) -> Object: ...
+
+class ObjectRValueRef:
+    """Represent an RValue ref to an object that can be moved."""
+
+    obj: Object
+    def __init__(self, obj: Object) -> None: ...
+
+class OpaquePyObject(Object):
+    """Opaque PyObject container."""
+
+    def pyobject(self) -> Any: ...
+
+class PyNativeObject:
+    """Base class of all TVM objects that also subclass python's builtin types."""
+
+    __slots__: list[str]
+    def __init_tvm_ffi_object_by_constructor__(
+        self, fconstructor: Function, *args: Any
+    ) -> None: ...
+
+def _set_class_object(cls: type) -> None: ...
+def _register_object_by_index(index: int, cls: type) -> None: ...
+def _object_type_key_to_index(type_key: str) -> int | None: ...
+def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ...
+
+class Error(Object):
+    """Base class for FFI errors."""
+
+    def __init__(self, kind: str, message: str, traceback: str) -> None: ...
+    def update_traceback(self, traceback: str) -> None: ...
+    def py_error(self) -> BaseException: ...
+    @property
+    def kind(self) -> str: ...
+    @property
+    def message(self) -> str: ...
+    @property
+    def traceback(self) -> str: ...
+
+def _convert_to_ffi_error(error: BaseException) -> Error: ...
+def _env_set_current_stream(device_type: int, device_id: int, stream: int) -> int: ...
+
+class DataType:
+    """DataType wrapper around DLDataType."""
+
+    def __init__(self, dtype_str: str) -> None: ...
+    def __reduce__(self) -> Any: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    @property
+    def type_code(self) -> int: ...
+    @property
+    def bits(self) -> int: ...
+    @property
+    def lanes(self) -> int: ...
+    @property
+    def itemsize(self) -> int: ...
+    def __str__(self) -> str: ...
+
+def _set_class_dtype(cls: type) -> None: ...
+def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ...
+
+class DLDeviceType(IntEnum):
+    """Enum that maps to DLDeviceType."""
+
+    kDLCPU = 1
+    kDLCUDA = 2
+    kDLCUDAHost = 3
+    kDLOpenCL = 4
+    kDLVulkan = 7
+    kDLMetal = 8
+    kDLVPI = 9
+    kDLROCM = 10
+    kDLROCMHost = 11
+    kDLExtDev = 12
+    kDLCUDAManaged = 13
+    kDLOneAPI = 14
+    kDLWebGPU = 15
+    kDLHexagon = 16
+
+class Device:
+    """Device represents a device in the ffi system."""
+
+    def __init__(self, device_type: str | int, index: int | None = None) -> None: ...
+    def __reduce__(self) -> Any: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __str__(self) -> str: ...
+    def __repr__(self) -> str: ...
+    def __hash__(self) -> int: ...
+    def __device_type_name__(self) -> str: ...
+    @property
+    def type(self) -> str: ...
+    @property
+    def index(self) -> int: ...
+    def dlpack_device_type(self) -> int: ...
+
+def _set_class_device(cls: type) -> None: ...
+
+_CLASS_DEVICE: type[Device]
+
+def _shape_obj_get_py_tuple(obj: Any) -> tuple[int, ...]: ...
+
+class Tensor(Object):
+    """Tensor object that represents a managed n-dimensional array."""
+
+    @property
+    def shape(self) -> tuple[int, ...]: ...
+    @property
+    def dtype(self) -> Any: ...  # returned as python dtype (str subclass)
+    @property
+    def device(self) -> Device: ...
+    def _to_dlpack(self) -> Any: ...
+    def _to_dlpack_versioned(self) -> Any: ...
+    def __dlpack_device__(self) -> tuple[int, int]: ...
+    def __dlpack__(
+        self,
+        *,
+        stream: Any | None = None,
+        max_version: tuple[int, int] | None = None,
+        dl_device: tuple[int, int] | None = None,
+        copy: bool | None = None,
+    ) -> Any: ...
+
+def from_dlpack(
+    ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool = ...
+) -> Tensor: ...
+
+class DLTensorTestWrapper:
+    """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose."""
+
+    __c_dlpack_from_pyobject__: int
+    def __init__(self, tensor: Tensor) -> None: ...
+    def __tvm_ffi_env_stream__(self) -> int: ...
+    def __dlpack_device__(self) -> tuple[int, int]: ...
+    def __dlpack__(self, **kwargs: Any) -> Any: ...
+
+def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...
+
+class Function(Object):
+    """Python class that wraps a function with tvm-ffi ABI."""
+
+    @property
+    def release_gil(self) -> bool: ...
+    @release_gil.setter
+    def release_gil(self, value: bool) -> None: ...
+    def __call__(self, *args: Any) -> Any: ...
+
+def _register_global_func(
+    name: str, pyfunc: Callable[..., Any] | Function, override: bool
+) -> Function: ...
+def _get_global_func(name: str, allow_missing: bool) -> Function | None: ...
+def _convert_to_ffi_func(pyfunc: Callable[..., Any]) -> Function: ...
+def _convert_to_opaque_object(pyobject: Any) -> OpaquePyObject: ...
+def _print_debug_info() -> None: ...
+
+class String(str, PyNativeObject):
+    __slots__ = ["__tvm_ffi_object__"]
+    __tvm_ffi_object__: Object | None
+
+    def __new__(cls, value: str) -> String: ...
+
+    # pylint: disable=no-self-argument
+    def __from_tvm_ffi_object__(cls, obj: Any) -> String: ...
+
+class Bytes(bytes, PyNativeObject):
+    __slots__ = ["__tvm_ffi_object__"]
+    __tvm_ffi_object__: Object | None
+
+    def __new__(cls, value: bytes) -> Bytes: ...
+
+    # pylint: disable=no-self-argument
+    def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index 5ad7571..da30464 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -148,6 +148,7 @@ FMT_MAP = {
     "go": header_cstyle,
     "java": header_cstyle,
     "h": header_cstyle,
+    "pyi": header_pystyle,
     "py": header_pystyle,
     "toml": header_pystyle,
     "yml": header_pystyle,

Reply via email to