This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 785e8ca feat: Add typestubs for Cython generated code (#24)
785e8ca is described below
commit 785e8ca1366b718cb95ae5521fbd39a06a36b778
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 18 19:52:55 2025 -0700
feat: Add typestubs for Cython generated code (#24)
After this PR, if Pylance still reports the `core` import as missing, you may
set:
```
"python.analysis.extraPaths": ["./build/"]
```
in the base directory.
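As a quick illustration (not part of this commit), with `core.pyi` in place a
static checker such as Pylance or mypy can verify code like the sketch below,
assuming `tvm_ffi` is importable in the analyzed environment:
```
# Hypothetical snippet: the checker resolves these symbols from core.pyi
# instead of reporting them as unknown imports.
import tvm_ffi
from tvm_ffi.core import Device, DLDeviceType

dev: Device = tvm_ffi.device("cpu", 0)    # checked against the Device stub
idx: int = dev.index                      # `index` is declared as an int property
cpu_code: int = int(DLDeviceType.kDLCPU)  # IntEnum member declared in the stub
```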
---
python/tvm_ffi/_tensor.py | 18 +++-
python/tvm_ffi/core.pyi | 227 +++++++++++++++++++++++++++++++++++++++++
tests/lint/check_asf_header.py | 1 +
3 files changed, 241 insertions(+), 5 deletions(-)
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 903a69d..8d06bd2 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -21,12 +21,20 @@
from numbers import Integral
from typing import Any, Optional, Union
-from . import _ffi_api, core, registry
-from .core import Device, DLDeviceType, Tensor, from_dlpack
+from . import _ffi_api, registry
+from .core import (
+    _CLASS_DEVICE,
+    Device,
+    DLDeviceType,
+    PyNativeObject,
+    Tensor,
+    _shape_obj_get_py_tuple,
+    from_dlpack,
+)
@registry.register_object("ffi.Shape")
-class Shape(tuple, core.PyNativeObject):
+class Shape(tuple, PyNativeObject):
"""Shape tuple that represents `ffi::Shape` returned by a ffi call.
Note:
@@ -46,7 +54,7 @@ class Shape(tuple, core.PyNativeObject):
    # pylint: disable=no-self-argument
    def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
        """Construct from a given tvm object."""
-        content = core._shape_obj_get_py_tuple(obj)
+        content = _shape_obj_get_py_tuple(obj)
        val = tuple.__new__(cls, content)
        val.__tvm_ffi_object__ = obj
        return val
@@ -78,7 +86,7 @@ def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = No
    assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)
    """
-    return core._CLASS_DEVICE(device_type, index)
+    return _CLASS_DEVICE(device_type, index)
__all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"]
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
new file mode 100644
index 0000000..d57d020
--- /dev/null
+++ b/python/tvm_ffi/core.pyi
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Typestubs for Cython."""
+
+from __future__ import annotations
+
+import types
+from enum import IntEnum
+from typing import Any, Callable
+
+# Public module-level variables referenced by Python code
+ERROR_NAME_TO_TYPE: dict[str, type]
+ERROR_TYPE_TO_NAME: dict[type, str]
+
+_WITH_APPEND_TRACEBACK: Callable[[BaseException, str], BaseException] | None
+_TRACEBACK_TO_STR: Callable[[types.TracebackType | None], str] | None
+
+# DLPack protocol version (defined in tensor.pxi)
+__dlpack_version__: tuple[int, int]
+
+class Object:
+    """Base class of all TVM FFI objects."""
+
+    def __ctypes_handle__(self) -> Any: ...
+    def __chandle__(self) -> int: ...
+    def __reduce__(self) -> Any: ...
+    def __getstate__(self) -> dict[str, Any]: ...
+    def __setstate__(self, state: dict[str, Any]) -> None: ...
+    def __repr__(self) -> str: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __hash__(self) -> int: ...
+    def __init_handle_by_constructor__(self, fconstructor: Function, *args: Any) -> None: ...
+    def same_as(self, other: Any) -> bool: ...
+    def _move(self) -> ObjectRValueRef: ...
+    def __move_handle_from__(self, other: Object) -> None: ...
+
+class ObjectConvertible:
+    """Base class for all classes that can be converted to Object."""
+
+    def asobject(self) -> Object: ...
+
+class ObjectRValueRef:
+    """Represent an RValue ref to an object that can be moved."""
+
+    obj: Object
+    def __init__(self, obj: Object) -> None: ...
+
+class OpaquePyObject(Object):
+    """Opaque PyObject container."""
+
+    def pyobject(self) -> Any: ...
+
+class PyNativeObject:
+    """Base class of all TVM objects that also subclass python's builtin types."""
+
+    __slots__: list[str]
+    def __init_tvm_ffi_object_by_constructor__(
+        self, fconstructor: Function, *args: Any
+    ) -> None: ...
+
+def _set_class_object(cls: type) -> None: ...
+def _register_object_by_index(index: int, cls: type) -> None: ...
+def _object_type_key_to_index(type_key: str) -> int | None: ...
+def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ...
+
+class Error(Object):
+    """Base class for FFI errors."""
+
+    def __init__(self, kind: str, message: str, traceback: str) -> None: ...
+    def update_traceback(self, traceback: str) -> None: ...
+    def py_error(self) -> BaseException: ...
+    @property
+    def kind(self) -> str: ...
+    @property
+    def message(self) -> str: ...
+    @property
+    def traceback(self) -> str: ...
+
+def _convert_to_ffi_error(error: BaseException) -> Error: ...
+def _env_set_current_stream(device_type: int, device_id: int, stream: int) -> int: ...
+
+class DataType:
+    """DataType wrapper around DLDataType."""
+
+    def __init__(self, dtype_str: str) -> None: ...
+    def __reduce__(self) -> Any: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    @property
+    def type_code(self) -> int: ...
+    @property
+    def bits(self) -> int: ...
+    @property
+    def lanes(self) -> int: ...
+    @property
+    def itemsize(self) -> int: ...
+    def __str__(self) -> str: ...
+
+def _set_class_dtype(cls: type) -> None: ...
+def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ...
+
+class DLDeviceType(IntEnum):
+    """Enum that maps to DLDeviceType."""
+
+    kDLCPU = 1
+    kDLCUDA = 2
+    kDLCUDAHost = 3
+    kDLOpenCL = 4
+    kDLVulkan = 7
+    kDLMetal = 8
+    kDLVPI = 9
+    kDLROCM = 10
+    kDLROCMHost = 11
+    kDLExtDev = 12
+    kDLCUDAManaged = 13
+    kDLOneAPI = 14
+    kDLWebGPU = 15
+    kDLHexagon = 16
+
+class Device:
+    """Device represents a device in the ffi system."""
+
+    def __init__(self, device_type: str | int, index: int | None = None) -> None: ...
+    def __reduce__(self) -> Any: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __str__(self) -> str: ...
+    def __repr__(self) -> str: ...
+    def __hash__(self) -> int: ...
+    def __device_type_name__(self) -> str: ...
+    @property
+    def type(self) -> str: ...
+    @property
+    def index(self) -> int: ...
+    def dlpack_device_type(self) -> int: ...
+
+def _set_class_device(cls: type) -> None: ...
+
+_CLASS_DEVICE: type[Device]
+
+def _shape_obj_get_py_tuple(obj: Any) -> tuple[int, ...]: ...
+
+class Tensor(Object):
+    """Tensor object that represents a managed n-dimensional array."""
+
+    @property
+    def shape(self) -> tuple[int, ...]: ...
+    @property
+    def dtype(self) -> Any: ...  # returned as python dtype (str subclass)
+    @property
+    def device(self) -> Device: ...
+    def _to_dlpack(self) -> Any: ...
+    def _to_dlpack_versioned(self) -> Any: ...
+    def __dlpack_device__(self) -> tuple[int, int]: ...
+    def __dlpack__(
+        self,
+        *,
+        stream: Any | None = None,
+        max_version: tuple[int, int] | None = None,
+        dl_device: tuple[int, int] | None = None,
+        copy: bool | None = None,
+    ) -> Any: ...
+
+def from_dlpack(
+    ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool = ...
+) -> Tensor: ...
+
+class DLTensorTestWrapper:
+    """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose."""
+
+    __c_dlpack_from_pyobject__: int
+    def __init__(self, tensor: Tensor) -> None: ...
+    def __tvm_ffi_env_stream__(self) -> int: ...
+    def __dlpack_device__(self) -> tuple[int, int]: ...
+    def __dlpack__(self, **kwargs: Any) -> Any: ...
+
+def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...
+
+class Function(Object):
+    """Python class that wraps a function with tvm-ffi ABI."""
+
+    @property
+    def release_gil(self) -> bool: ...
+    @release_gil.setter
+    def release_gil(self, value: bool) -> None: ...
+    def __call__(self, *args: Any) -> Any: ...
+
+def _register_global_func(
+    name: str, pyfunc: Callable[..., Any] | Function, override: bool
+) -> Function: ...
+def _get_global_func(name: str, allow_missing: bool) -> Function | None: ...
+def _convert_to_ffi_func(pyfunc: Callable[..., Any]) -> Function: ...
+def _convert_to_opaque_object(pyobject: Any) -> OpaquePyObject: ...
+def _print_debug_info() -> None: ...
+
+class String(str, PyNativeObject):
+    __slots__ = ["__tvm_ffi_object__"]
+    __tvm_ffi_object__: Object | None
+
+    def __new__(cls, value: str) -> String: ...
+
+    # pylint: disable=no-self-argument
+    def __from_tvm_ffi_object__(cls, obj: Any) -> String: ...
+
+class Bytes(bytes, PyNativeObject):
+    __slots__ = ["__tvm_ffi_object__"]
+    __tvm_ffi_object__: Object | None
+
+    def __new__(cls, value: bytes) -> Bytes: ...
+
+    # pylint: disable=no-self-argument
+    def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index 5ad7571..da30464 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -148,6 +148,7 @@ FMT_MAP = {
"go": header_cstyle,
"java": header_cstyle,
"h": header_cstyle,
+ "pyi": header_pystyle,
"py": header_pystyle,
"toml": header_pystyle,
"yml": header_pystyle,