This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 40e9c83 feat: Enable mypy type checking (#35)
40e9c83 is described below
commit 40e9c83345636b187f39df86a077e0b56683c792
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 12:58:15 2025 -0700
feat: Enable mypy type checking (#35)
This PR enables mypy type checking for the following
directories:
- `python/` and `tests/`
- `examples/inline_module`
- `examples/packaging`
- `examples/quick_start`
---
.pre-commit-config.yaml | 27 +++++
.../packaging/python/my_ffi_extension/__init__.py | 2 +-
examples/quick_start/run_example.py | 18 +---
pyproject.toml | 4 +
python/tvm_ffi/_convert.py | 11 +-
python/tvm_ffi/_dtype.py | 1 +
python/tvm_ffi/_ffi_api.pyi | 43 ++++++++
python/tvm_ffi/_optional_torch_c_dlpack.py | 12 +--
python/tvm_ffi/_tensor.py | 20 ++--
python/tvm_ffi/access_path.py | 27 +++--
python/tvm_ffi/container.py | 10 +-
python/tvm_ffi/core.pyi | 17 +--
python/tvm_ffi/cpp/load_inline.py | 115 +++++++++++----------
python/tvm_ffi/cython/type_info.pxi | 11 +-
python/tvm_ffi/dataclasses/_utils.py | 23 ++---
python/tvm_ffi/dataclasses/c_class.py | 13 ++-
python/tvm_ffi/dataclasses/field.py | 11 +-
python/tvm_ffi/error.py | 12 ++-
python/tvm_ffi/libinfo.py | 37 ++++---
python/tvm_ffi/module.py | 4 +-
python/tvm_ffi/registry.py | 62 +++++++----
python/tvm_ffi/serialization.py | 6 +-
python/tvm_ffi/stream.py | 10 +-
python/tvm_ffi/testing.py | 19 +++-
python/tvm_ffi/utils/lockfile.py | 16 +--
tests/lint/check_asf_header.py | 2 +-
tests/lint/check_file_type.py | 2 +-
tests/python/test_container.py | 4 +-
tests/python/test_dataclasses_c_class.py | 2 +-
tests/python/test_device.py | 4 +-
tests/python/test_error.py | 20 ++--
tests/python/test_function.py | 10 +-
tests/python/test_load_inline.py | 8 +-
tests/python/test_object.py | 14 ++-
tests/python/test_stream.py | 10 +-
tests/python/test_tensor.py | 13 ++-
tests/scripts/benchmark_dlpack.py | 6 +-
37 files changed, 395 insertions(+), 231 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8f762dc..329c00c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -83,3 +83,30 @@ repos:
rev: v0.10.0.1
hooks:
- id: shellcheck
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: "v1.17.0"
+ hooks:
+ - id: mypy
+ name: mypy for `python/` and `tests/`
+ additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1",
"pytest", "typing-extensions>=4.5"]
+ args: [--show-error-codes, --python-version=3.9]
+ exclude: ^.*/_ffi_api\.py$
+ files: ^(python/|tests/).*\.py$
+ - id: mypy
+ name: mypy for `examples/inline_module`
+ additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1",
"pytest", "typing-extensions>=4.5"]
+ args: [--show-error-codes, --python-version=3.9]
+ exclude: ^.*/_ffi_api\.py$
+ files: ^examples/inline_module/.*\.py$
+ - id: mypy
+ name: mypy for `examples/packaging`
+ additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1",
"pytest", "typing-extensions>=4.5"]
+ args: [--show-error-codes, --python-version=3.9]
+ exclude: ^.*/_ffi_api\.py$
+ files: ^examples/packaging/.*\.py$
+ - id: mypy
+ name: mypy for `examples/quick_start`
+ additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1",
"pytest", "typing-extensions>=4.5"]
+ args: [--show-error-codes, --python-version=3.9]
+ exclude: ^.*/_ffi_api\.py$
+ files: ^examples/quick_start/.*\.py$
diff --git a/examples/packaging/python/my_ffi_extension/__init__.py
b/examples/packaging/python/my_ffi_extension/__init__.py
index ae4abfd..0c2b0fd 100644
--- a/examples/packaging/python/my_ffi_extension/__init__.py
+++ b/examples/packaging/python/my_ffi_extension/__init__.py
@@ -51,4 +51,4 @@ def raise_error(msg: str) -> None:
The error raised by the function.
"""
- return _ffi_api.raise_error(msg)
+ return _ffi_api.raise_error(msg) # type: ignore[attr-defined]
diff --git a/examples/quick_start/run_example.py
b/examples/quick_start/run_example.py
index 87c9507..2e2b7f3 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -16,15 +16,9 @@
# under the License.
"""Quick start script to run tvm-ffi examples from prebuilt libraries."""
-import tvm_ffi
-
-try:
- import torch
-except ImportError:
- torch = None
-
-
import numpy
+import torch
+import tvm_ffi
def run_add_one_cpu() -> None:
@@ -40,9 +34,6 @@ def run_add_one_cpu() -> None:
print("numpy.result after add_one(x, y)")
print(x)
- if torch is None:
- return
-
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
# tvm-ffi automatically handles DLPack compatible tensors
@@ -63,9 +54,6 @@ def run_add_one_c() -> None:
print("numpy.result after add_one_c(x, y)")
print(x)
- if torch is None:
- return
-
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
mod.add_one_c(x, y)
@@ -75,7 +63,7 @@ def run_add_one_c() -> None:
def run_add_one_cuda() -> None:
"""Load the add_one_cuda module and call the add_one_cuda function."""
- if torch is None or not torch.cuda.is_available():
+ if not torch.cuda.is_available():
return
mod = tvm_ffi.load_module("build/add_one_cuda.so")
diff --git a/pyproject.toml b/pyproject.toml
index d2fd897..58d394a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -213,3 +213,7 @@ environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" }
[tool.cibuildwheel.windows]
archs = ["AMD64"]
+
+[tool.mypy]
+allow_redefinition = true
+ignore_missing_imports = true
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 6313c8e..05e99da 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -16,20 +16,25 @@
# under the License.
"""Conversion utilities to bring python objects into ffi values."""
+from __future__ import annotations
+
from numbers import Number
+from types import ModuleType
from typing import Any
from . import container, core
+torch: ModuleType | None = None
try:
- import torch
+ import torch # type: ignore[no-redef]
except ImportError:
- torch = None
+ pass
+numpy: ModuleType | None = None
try:
import numpy
except ImportError:
- numpy = None
+ pass
def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index ba1735f..8306593 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -59,6 +59,7 @@ class dtype(str):
"""
__slots__ = ["__tvm_ffi_dtype__"]
+ __tvm_ffi_dtype__: core.DataType
_NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
diff --git a/python/tvm_ffi/_ffi_api.pyi b/python/tvm_ffi/_ffi_api.pyi
new file mode 100644
index 0000000..95059e5
--- /dev/null
+++ b/python/tvm_ffi/_ffi_api.pyi
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API."""
+
+from typing import Any
+
+def ModuleGetKind(*args: Any) -> Any: ...
+def ModuleImplementsFunction(*args: Any) -> Any: ...
+def ModuleGetFunction(*args: Any) -> Any: ...
+def ModuleImportModule(*args: Any) -> Any: ...
+def ModuleInspectSource(*args: Any) -> Any: ...
+def ModuleGetWriteFormats(*args: Any) -> Any: ...
+def ModuleGetPropertyMask(*args: Any) -> Any: ...
+def ModuleClearImports(*args: Any) -> Any: ...
+def ModuleWriteToFile(*args: Any) -> Any: ...
+def ModuleLoadFromFile(*args: Any) -> Any: ...
+def SystemLib(*args: Any) -> Any: ...
+def Array(*args: Any) -> Any: ...
+def ArrayGetItem(*args: Any) -> Any: ...
+def ArraySize(*args: Any) -> Any: ...
+def MapForwardIterFunctor(*args: Any) -> Any: ...
+def Map(*args: Any) -> Any: ...
+def MapGetItem(*args: Any) -> Any: ...
+def MapCount(*args: Any) -> Any: ...
+def MapSize(*args: Any) -> Any: ...
+def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
+def ToJSONGraphString(*args: Any) -> Any: ...
+def FromJSONGraphString(*args: Any) -> Any: ...
+def Shape(*args: Any) -> Any: ...
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index dd820a9..5be7211 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -31,12 +31,12 @@ subsequent calls will be much faster.
"""
import warnings
-from typing import Any, Optional
+from typing import Any
from . import libinfo
-def load_torch_c_dlpack_extension() -> Optional[Any]:
+def load_torch_c_dlpack_extension() -> Any:
"""Load the torch c dlpack extension."""
cpp_source = """
#include <dlpack/dlpack.h>
@@ -556,9 +556,9 @@ int64_t TorchDLPackTensorAllocatorPtr() {
extra_include_paths=include_paths,
)
# set the dlpack related flags
- torch.Tensor.__c_dlpack_from_pyobject__ =
mod.TorchDLPackFromPyObjectPtr()
- torch.Tensor.__c_dlpack_to_pyobject__ = mod.TorchDLPackToPyObjectPtr()
- torch.Tensor.__c_dlpack_tensor_allocator__ =
mod.TorchDLPackTensorAllocatorPtr()
+ setattr(torch.Tensor, "__c_dlpack_from_pyobject__",
mod.TorchDLPackFromPyObjectPtr())
+ setattr(torch.Tensor, "__c_dlpack_to_pyobject__",
mod.TorchDLPackToPyObjectPtr())
+ setattr(torch.Tensor, "__c_dlpack_tensor_allocator__",
mod.TorchDLPackTensorAllocatorPtr())
return mod
except ImportError:
pass
@@ -566,7 +566,7 @@ int64_t TorchDLPackTensorAllocatorPtr() {
warnings.warn(
f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator
will not be enabled."
)
- return None
+ return None
# keep alive
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 0cc09f1..0d44994 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Tensor related objects and functions."""
+
+from __future__ import annotations
+
# we name it as _tensor.py to avoid potential future case
# if we also want to expose a tensor function in the root namespace
-
from numbers import Integral
-from typing import Any, Optional, Union
+from typing import Any
from . import _ffi_api, core, registry
from .core import (
@@ -43,23 +45,25 @@ class Shape(tuple, PyNativeObject):
"""
- def __new__(cls, content: tuple[int, ...]) -> "Shape":
+ __tvm_ffi_object__: Any
+
+ def __new__(cls, content: tuple[int, ...]) -> Shape:
if any(not isinstance(x, Integral) for x in content):
raise ValueError("Shape must be a tuple of integers")
- val = tuple.__new__(cls, content)
+ val: Shape = tuple.__new__(cls, content)
val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content)
return val
# pylint: disable=no-self-argument
- def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
+ def __from_tvm_ffi_object__(cls, obj: Any) -> Shape:
"""Construct from a given tvm object."""
content = _shape_obj_get_py_tuple(obj)
- val = tuple.__new__(cls, content)
- val.__tvm_ffi_object__ = obj
+ val: Shape = tuple.__new__(cls, content) # type: ignore[arg-type]
+ val.__tvm_ffi_object__ = obj # type: ignore[attr-defined]
return val
-def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] =
None) -> Device:
+def device(device_type: str | int | DLDeviceType, index: int | None = None) ->
Device:
"""Construct a TVM FFI device with given device type and index.
Parameters
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index e8aec10..aa52d58 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -39,11 +39,16 @@ class AccessKind(IntEnum):
class AccessStep(core.Object):
"""Access step container."""
+ kind: AccessKind
+ key: Any
+
@register_object("ffi.reflection.AccessPath")
class AccessPath(core.Object):
"""Access path container."""
+ parent: "AccessPath"
+
def __init__(self) -> None:
"""Disallow direct construction; use `AccessPath.root()` instead."""
super().__init__()
@@ -55,19 +60,19 @@ class AccessPath(core.Object):
@staticmethod
def root() -> "AccessPath":
"""Create a root access path."""
- return AccessPath._root()
+ return AccessPath._root() # type: ignore[attr-defined]
def __eq__(self, other: Any) -> bool:
"""Return whether two access paths are equal."""
if not isinstance(other, AccessPath):
return False
- return self._path_equal(other)
+ return self._path_equal(other) # type: ignore[attr-defined]
def __ne__(self, other: Any) -> bool:
"""Return whether two access paths are not equal."""
if not isinstance(other, AccessPath):
return True
- return not self._path_equal(other)
+ return not self._path_equal(other) # type: ignore[attr-defined]
def is_prefix_of(self, other: "AccessPath") -> bool:
"""Check if this access path is a prefix of another access path.
@@ -83,7 +88,7 @@ class AccessPath(core.Object):
True if this access path is a prefix of the other access path,
False otherwise
"""
- return self._is_prefix_of(other)
+ return self._is_prefix_of(other) # type: ignore[attr-defined]
def attr(self, attr_key: str) -> "AccessPath":
"""Create an access path to the attribute of the current object.
@@ -99,7 +104,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._attr(attr_key)
+ return self._attr(attr_key) # type: ignore[attr-defined]
def attr_missing(self, attr_key: str) -> "AccessPath":
"""Create an access path that indicate an attribute is missing.
@@ -115,7 +120,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._attr_missing(attr_key)
+ return self._attr_missing(attr_key) # type: ignore[attr-defined]
def array_item(self, index: int) -> "AccessPath":
"""Create an access path to the item of the current array.
@@ -131,7 +136,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._array_item(index)
+ return self._array_item(index) # type: ignore[attr-defined]
def array_item_missing(self, index: int) -> "AccessPath":
"""Create an access path that indicate an array item is missing.
@@ -147,7 +152,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._array_item_missing(index)
+ return self._array_item_missing(index) # type: ignore[attr-defined]
def map_item(self, key: Any) -> "AccessPath":
"""Create an access path to the item of the current map.
@@ -163,7 +168,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._map_item(key)
+ return self._map_item(key) # type: ignore[attr-defined]
def map_item_missing(self, key: Any) -> "AccessPath":
"""Create an access path that indicate a map item is missing.
@@ -179,7 +184,7 @@ class AccessPath(core.Object):
The extended access path
"""
- return self._map_item_missing(key)
+ return self._map_item_missing(key) # type: ignore[attr-defined]
def to_steps(self) -> list["AccessStep"]:
"""Convert the access path to a list of access steps.
@@ -190,6 +195,6 @@ class AccessPath(core.Object):
The list of access steps
"""
- return self._to_steps()
+ return self._to_steps() # type: ignore[attr-defined]
__hash__ = core.Object.__hash__
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 008dda9..6f29dfd 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -16,6 +16,8 @@
# under the License.
"""Container classes."""
+from __future__ import annotations
+
import collections.abc
from collections.abc import Iterator, Mapping, Sequence
from typing import Any, Callable
@@ -121,7 +123,7 @@ class Array(core.Object, collections.abc.Sequence):
class KeysView(collections.abc.KeysView):
"""Helper class to return keys view."""
- def __init__(self, backend_map: "Map") -> None:
+ def __init__(self, backend_map: Map) -> None:
self._backend_map = backend_map
def __len__(self) -> int:
@@ -144,7 +146,7 @@ class KeysView(collections.abc.KeysView):
class ValuesView(collections.abc.ValuesView):
"""Helper class to return values view."""
- def __init__(self, backend_map: "Map") -> None:
+ def __init__(self, backend_map: Map) -> None:
self._backend_map = backend_map
def __len__(self) -> int:
@@ -164,7 +166,7 @@ class ValuesView(collections.abc.ValuesView):
class ItemsView(collections.abc.ItemsView):
"""Helper class to return items view."""
- def __init__(self, backend_map: "Map") -> None:
+ def __init__(self, backend_map: Map) -> None:
self.backend_map = backend_map
def __len__(self) -> int:
@@ -231,7 +233,7 @@ class Map(core.Object, collections.abc.Mapping):
"""Return a dynamic view of the map's keys."""
return KeysView(self)
- def values(self) -> "ValuesView":
+ def values(self) -> ValuesView:
"""Return a dynamic view of the map's values."""
return ValuesView(self)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index cfb3ea9..7c623af 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -19,6 +19,7 @@
from __future__ import annotations
import types
+from ctypes import c_void_p
from enum import IntEnum
from typing import Any, Callable
@@ -28,7 +29,6 @@ ERROR_TYPE_TO_NAME: dict[type, str]
_WITH_APPEND_BACKTRACE: Callable[[BaseException, str], BaseException] | None
_TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType | None], str] | None
-
# DLPack protocol version (defined in tensor.pxi)
__dlpack_version__: tuple[int, int]
@@ -44,7 +44,7 @@ class Object:
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __hash__(self) -> int: ...
- def __init_handle_by_constructor__(self, fconstructor: Function, *args:
Any) -> None: ...
+ def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) ->
None: ...
def __ffi_init__(self, *args: Any) -> None:
"""Initialize the instance using the ` __init__` method registered on
C++ side.
@@ -78,9 +78,7 @@ class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin
types."""
__slots__: list[str]
- def __init_tvm_ffi_object_by_constructor__(
- self, fconstructor: Function, *args: Any
- ) -> None: ...
+ def __init_tvm_ffi_object_by_constructor__(self, fconstructor: Any, *args:
Any) -> None: ...
def _set_class_object(cls: type) -> None: ...
def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
@@ -101,7 +99,9 @@ class Error(Object):
def backtrace(self) -> str: ...
def _convert_to_ffi_error(error: BaseException) -> Error: ...
-def _env_set_current_stream(device_type: int, device_id: int, stream: int) ->
int: ...
+def _env_set_current_stream(
+ device_type: int, device_id: int, stream: int | c_void_p
+) -> int | c_void_p: ...
class DataType:
"""DataType wrapper around DLDataType."""
@@ -121,6 +121,8 @@ class DataType:
def __str__(self) -> str: ...
def _set_class_dtype(cls: type) -> None: ...
+def _convert_torch_dtype_to_ffi_dtype(torch_dtype: Any) -> DataType: ...
+def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) -> DataType: ...
def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes:
int) -> DataType: ...
class DLDeviceType(IntEnum):
@@ -185,6 +187,9 @@ class Tensor(Object):
copy: bool | None = None,
) -> Any: ...
+_CLASS_TENSOR: type[Tensor] = Tensor
+
+def _set_class_tensor(cls: type[Tensor]) -> None: ...
def from_dlpack(
ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool
= ...
) -> Tensor: ...
diff --git a/python/tvm_ffi/cpp/load_inline.py
b/python/tvm_ffi/cpp/load_inline.py
index 2c1caf5..d7a5c14 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -16,6 +16,8 @@
# under the License.
"""Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja."""
+from __future__ import annotations
+
import functools
import hashlib
import os
@@ -24,7 +26,6 @@ import subprocess
import sys
from collections.abc import Mapping, Sequence
from pathlib import Path
-from typing import Optional
from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path,
find_libtvm_ffi
from tvm_ffi.module import Module, load_module
@@ -77,7 +78,7 @@ def _maybe_write(path: str, content: str) -> None:
@functools.lru_cache
-def _find_cuda_home() -> Optional[str]:
+def _find_cuda_home() -> str:
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
@@ -92,9 +93,10 @@ def _find_cuda_home() -> Optional[str]:
cuda_root = Path("C:/Program Files/NVIDIA GPU Computing
Toolkit/CUDA")
cuda_homes = list(cuda_root.glob("v*.*"))
if len(cuda_homes) == 0:
- cuda_home = ""
- else:
- cuda_home = str(cuda_homes[0])
+ raise RuntimeError(
+ "Could not find CUDA installation. Please set
CUDA_HOME environment variable."
+ )
+ cuda_home = str(cuda_homes[0])
else:
cuda_home = "/usr/local/cuda"
if not Path(cuda_home).exists():
@@ -358,17 +360,17 @@ def _decorate_with_tvm_ffi(source: str, functions:
Mapping[str, str]) -> str:
return "\n".join(sources)
-def load_inline(
+def load_inline( # noqa: PLR0912, PLR0915
name: str,
*,
- cpp_sources: str | None = None,
- cuda_sources: str | None = None,
- functions: Sequence[str] | None = None,
+ cpp_sources: Sequence[str] | str | None = None,
+ cuda_sources: Sequence[str] | str | None = None,
+ functions: Mapping[str, str] | Sequence[str] | str | None = None,
extra_cflags: Sequence[str] | None = None,
extra_cuda_cflags: Sequence[str] | None = None,
extra_ldflags: Sequence[str] | None = None,
extra_include_paths: Sequence[str] | None = None,
- build_directory: Optional[str] = None,
+ build_directory: str | None = None,
) -> Module:
"""Compile and load a C++/CUDA module from inline source code.
@@ -481,76 +483,83 @@ def load_inline(
"""
if cpp_sources is None:
- cpp_sources = []
+ cpp_source_list: list[str] = []
elif isinstance(cpp_sources, str):
- cpp_sources = [cpp_sources]
- cpp_source = "\n".join(cpp_sources)
+ cpp_source_list = [cpp_sources]
+ else:
+ cpp_source_list = list(cpp_sources)
+ cpp_source = "\n".join(cpp_source_list)
+ with_cpp = bool(cpp_source_list)
+ del cpp_source_list
+
if cuda_sources is None:
- cuda_sources = []
+ cuda_source_list: list[str] = []
elif isinstance(cuda_sources, str):
- cuda_sources = [cuda_sources]
- cuda_source = "\n".join(cuda_sources)
- with_cpp = len(cpp_sources) > 0
- with_cuda = len(cuda_sources) > 0
+ cuda_source_list = [cuda_sources]
+ else:
+ cuda_source_list = list(cuda_sources)
+ cuda_source = "\n".join(cuda_source_list)
+ with_cuda = bool(cuda_source_list)
+ del cuda_source_list
- extra_ldflags = extra_ldflags or []
- extra_cflags = extra_cflags or []
- extra_cuda_cflags = extra_cuda_cflags or []
- extra_include_paths = extra_include_paths or []
+ extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else
[]
+ extra_cflags_list = list(extra_cflags) if extra_cflags is not None else []
+ extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is
not None else []
+ extra_include_paths_list = list(extra_include_paths) if
extra_include_paths is not None else []
# add function registration code to sources
- if isinstance(functions, str):
- functions = {functions: ""}
- elif isinstance(functions, Sequence):
- functions = {name: "" for name in functions}
+ if functions is None:
+ function_map: dict[str, str] = {}
+ elif isinstance(functions, str):
+ function_map = {functions: ""}
+ elif isinstance(functions, Mapping):
+ function_map = dict(functions)
+ else:
+ function_map = {name: "" for name in functions}
if with_cpp:
- cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
+ cpp_source = _decorate_with_tvm_ffi(cpp_source, function_map)
cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
else:
cpp_source = _decorate_with_tvm_ffi(cpp_source, {})
- cuda_source = _decorate_with_tvm_ffi(cuda_source, functions)
+ cuda_source = _decorate_with_tvm_ffi(cuda_source, function_map)
# determine the cache dir for the built module
+ build_dir: Path
if build_directory is None:
- build_directory = os.environ.get(
- "TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser())
- )
+ cache_dir = os.environ.get("TVM_FFI_CACHE_DIR",
str(Path("~/.cache/tvm-ffi").expanduser()))
source_hash: str = _hash_sources(
cpp_source,
cuda_source,
- functions,
- extra_cflags,
- extra_cuda_cflags,
- extra_ldflags,
- extra_include_paths,
+ function_map,
+ extra_cflags_list,
+ extra_cuda_cflags_list,
+ extra_ldflags_list,
+ extra_include_paths_list,
)
- build_dir: str = str(Path(build_directory) / f"{name}_{source_hash}")
+ build_dir = Path(cache_dir).expanduser() / f"{name}_{source_hash}"
else:
- build_dir = str(Path(build_directory).resolve())
- Path(build_dir).mkdir(parents=True, exist_ok=True)
+ build_dir = Path(build_directory).resolve()
+ build_dir.mkdir(parents=True, exist_ok=True)
# generate build.ninja
ninja_source = _generate_ninja_build(
name=name,
- build_dir=build_dir,
+ build_dir=str(build_dir),
with_cuda=with_cuda,
- extra_cflags=extra_cflags,
- extra_cuda_cflags=extra_cuda_cflags,
- extra_ldflags=extra_ldflags,
- extra_include_paths=extra_include_paths,
+ extra_cflags=extra_cflags_list,
+ extra_cuda_cflags=extra_cuda_cflags_list,
+ extra_ldflags=extra_ldflags_list,
+ extra_include_paths=extra_include_paths_list,
)
-
- with FileLock(str(Path(build_dir) / "lock")):
+ with FileLock(str(build_dir / "lock")):
# write source files and build.ninja if they do not already exist
- _maybe_write(str(Path(build_dir) / "main.cpp"), cpp_source)
+ _maybe_write(str(build_dir / "main.cpp"), cpp_source)
if with_cuda:
- _maybe_write(str(Path(build_dir) / "cuda.cu"), cuda_source)
- _maybe_write(str(Path(build_dir) / "build.ninja"), ninja_source)
-
+ _maybe_write(str(build_dir / "cuda.cu"), cuda_source)
+ _maybe_write(str(build_dir / "build.ninja"), ninja_source)
# build the module
- _build_ninja(build_dir)
-
+ _build_ninja(str(build_dir))
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- return load_module(str((Path(build_dir) / f"{name}{ext}").resolve()))
+ return load_module(str((build_dir / f"{name}{ext}").resolve()))
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 4ab9f15..fcd443b 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import dataclasses
+from typing import Optional, Any
cdef class FieldGetter:
@@ -62,13 +63,13 @@ class TypeField:
"""Description of a single reflected field on an FFI-backed type."""
name: str
- doc: str | None
+ doc: Optional[str]
size: int
offset: int
frozen: bool
getter: FieldGetter
setter: FieldSetter
- dataclass_field: object | None = None
+ dataclass_field: Any = None
def __post_init__(self):
assert self.setter is not None
@@ -96,7 +97,7 @@ class TypeMethod:
"""Description of a single reflected method on an FFI-backed type."""
name: str
- doc: str | None
+ doc: Optional[str]
func: object
is_static: bool
@@ -105,9 +106,9 @@ class TypeMethod:
class TypeInfo:
"""Aggregated type information required to build a proxy class."""
- type_cls: type | None
+ type_cls: Optional[type]
type_index: int
type_key: str
fields: list[TypeField]
methods: list[TypeMethod]
- parent_type_info: TypeInfo | None
+ parent_type_info: Optional[TypeInfo]
diff --git a/python/tvm_ffi/dataclasses/_utils.py
b/python/tvm_ffi/dataclasses/_utils.py
index ef7c7e4..7a28e01 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import functools
import inspect
from dataclasses import MISSING
-from typing import Any, Callable, NamedTuple, TypeVar
+from typing import Any, Callable, NamedTuple, TypeVar, cast
from ..core import (
Object,
@@ -68,7 +68,7 @@ def type_info_to_cls(
attrs[field.name] = field.as_property(cls)
# Step 3. Add methods
- def _add_method(name: str, func: Callable) -> None:
+ def _add_method(name: str, func: Callable[..., Any]) -> None:
if name == "__ffi_init__":
name = "__c_ffi_init__"
if name in attrs: # already defined
@@ -80,9 +80,9 @@ def type_info_to_cls(
attrs[name] = func
setattr(cls, name, func)
- for name, method in methods.items():
- if method is not None:
- _add_method(name, method)
+ for name, method_impl in methods.items():
+ if method_impl is not None:
+ _add_method(name, method_impl)
for method in type_info.methods:
_add_method(method.name, method.func)
@@ -90,7 +90,7 @@ def type_info_to_cls(
new_cls = type(cls.__name__, cls_bases, attrs)
new_cls.__module__ = cls.__module__
new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore
- return new_cls
+ return cast(type[_InputClsType], new_cls)
def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
@@ -123,15 +123,12 @@ def method_init(type_cls: type, type_info: TypeInfo) ->
Callable[..., None]: #
fn: Callable[[], Any]
- fields: list[TypeInfo] = []
- cur_type_info = type_info
- while True:
+ fields: list[TypeField] = []
+ cur_type_info: TypeInfo | None = type_info
+ while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
- if cur_type_info is None:
- break
fields.reverse()
- del cur_type_info
annotations: dict[str, Any] = {"return": None}
# Step 1. Split the parameters into two groups to ensure that
@@ -187,7 +184,7 @@ def method_init(type_cls: type, type_info: TypeInfo) ->
Callable[..., None]: #
else:
raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
- def __init__(self: type, *args: Any, **kwargs: Any) -> None:
+ def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
e = None
try:
args = bind_args(*args, **kwargs)
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index 7507b76..700f287 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -28,18 +28,21 @@ from dataclasses import InitVar
from typing import ClassVar, TypeVar, get_origin, get_type_hints
from ..core import TypeField, TypeInfo
-from . import _utils, field
+from . import _utils
+from .field import Field, field
try:
- from typing import dataclass_transform
+ from typing_extensions import dataclass_transform # type:
ignore[attr-defined]
except ImportError:
- from typing_extensions import dataclass_transform
+ from typing import dataclass_transform # type:
ignore[no-redef,attr-defined]
+except ImportError:
+ pass
_InputClsType = TypeVar("_InputClsType")
-@dataclass_transform(field_specifiers=(field.field, field.Field))
+@dataclass_transform(field_specifiers=(field, Field))
def c_class(
type_key: str, init: bool = True
) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
@@ -157,7 +160,7 @@ def _inspect_c_class_fields(type_cls: type, type_info:
TypeInfo) -> list[TypeFie
for field_name, _field_ty_py in type_hints_py.items():
if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip
continue
- type_field: TypeField = type_fields_cxx.pop(field_name, None)
+ type_field = type_fields_cxx.pop(field_name, None)
if type_field is None:
raise ValueError(
f"Extraneous field `{type_cls}.{field_name}`. Defined in
Python but not in C++"
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
index 00170e5..f1a582e 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -18,11 +18,10 @@
from __future__ import annotations
-from dataclasses import MISSING, dataclass
+from dataclasses import MISSING
from typing import Any, Callable
-@dataclass(kw_only=True)
class Field:
"""(Experimental) Descriptor placeholder returned by
:func:`tvm_ffi.dataclasses.field`.
@@ -36,8 +35,12 @@ class Field:
way the decorator understands.
"""
- name: str | None = None
- default_factory: Callable[[], Any]
+ __slots__ = ("default_factory", "name")
+
+ def __init__(self, *, name: str | None = None, default_factory:
Callable[[], Any]) -> None:
+ """Do not call directly; use :func:`field` instead."""
+ self.name = name
+ self.default_factory = default_factory
def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field:
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index f29fc90..fd0bf2b 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -17,11 +17,13 @@
# pylint: disable=invalid-name
"""Error handling."""
+from __future__ import annotations
+
import ast
import re
import sys
import types
-from typing import Any, Optional
+from typing import Any
from . import core
@@ -60,7 +62,7 @@ class TracebackManager:
def __init__(self) -> None:
"""Initialize the traceback manager and its cache."""
- self._code_cache = {}
+ self._code_cache: dict[tuple[str, int, str], types.CodeType] = {}
def _get_cached_code_object(self, filename: str, lineno: int, func: str)
-> types.CodeType:
# Hack to create a code object that points to the correct
@@ -95,7 +97,7 @@ class TracebackManager:
def append_traceback(
self,
- tb: Optional[types.TracebackType],
+ tb: types.TracebackType | None,
filename: str,
lineno: int,
func: str,
@@ -134,7 +136,7 @@ def _with_append_backtrace(py_error: BaseException,
backtrace: str) -> BaseExcep
return py_error.with_traceback(tb)
-def _traceback_to_backtrace_str(tb: Optional[types.TracebackType]) -> str:
+def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str:
"""Convert the traceback to a string."""
lines = []
while tb is not None:
@@ -155,7 +157,7 @@ core._TRACEBACK_TO_BACKTRACE_STR =
_traceback_to_backtrace_str
def register_error(
name_or_cls: str | type | None = None,
- cls: Optional[type] = None,
+ cls: type | None = None,
) -> Any:
"""Register an error class so it can be recognized by the ffi error
handler.
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index 382690b..8d92df3 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -46,25 +46,24 @@ def split_env_var(env_var: str, split: str) -> list[str]:
def get_dll_directories() -> list[str]:
"""Get the possible dll directories."""
ffi_dir = Path(__file__).expanduser().resolve().parent
- dll_path = [ffi_dir / "lib"]
- dll_path += [ffi_dir / ".." / ".." / "build" / "lib"]
+ dll_path: list[Path] = [ffi_dir / "lib"]
+ dll_path.append(ffi_dir / ".." / ".." / "build" / "lib")
# in source build from parent if needed
- dll_path += [ffi_dir / ".." / ".." / ".." / "build" / "lib"]
-
+ dll_path.append(ffi_dir / ".." / ".." / ".." / "build" / "lib")
if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
- dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":"))
- dll_path.extend(split_env_var("PATH", ":"))
+ dll_path.extend(Path(p) for p in split_env_var("LD_LIBRARY_PATH", ":"))
+ dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
elif sys.platform.startswith("darwin"):
- dll_path.extend(split_env_var("DYLD_LIBRARY_PATH", ":"))
- dll_path.extend(split_env_var("PATH", ":"))
+ dll_path.extend(Path(p) for p in split_env_var("DYLD_LIBRARY_PATH",
":"))
+ dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
elif sys.platform.startswith("win32"):
- dll_path.extend(split_env_var("PATH", ";"))
- return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()]
+ dll_path.extend(Path(p) for p in split_env_var("PATH", ";"))
+ return [str(path.resolve()) for path in dll_path if path.is_dir()]
def find_libtvm_ffi() -> str:
"""Find libtvm_ffi."""
- dll_path = get_dll_directories()
+ dll_path = [Path(p) for p in get_dll_directories()]
if sys.platform.startswith("win32"):
lib_dll_names = ["tvm_ffi.dll"]
elif sys.platform.startswith("darwin"):
@@ -72,14 +71,18 @@ def find_libtvm_ffi() -> str:
else:
lib_dll_names = ["libtvm_ffi.so"]
- name = lib_dll_names
- lib_dll_path = [str(Path(p) / name) for name in lib_dll_names for p in
dll_path]
- lib_found = [p for p in lib_dll_path if Path(p).exists() and
Path(p).is_file()]
+ lib_dll_path = [p / name for name in lib_dll_names for p in dll_path]
+ lib_found = [p for p in lib_dll_path if p.exists() and p.is_file()]
if not lib_found:
- raise RuntimeError(f"Cannot find library: {name}\nList of
candidates:\n{lib_dll_path}")
-
- return lib_found[0]
+ candidate_list = "\n".join(str(p) for p in lib_dll_path)
+ raise RuntimeError(
+ "Cannot find library: {}\nList of candidates:\n{}".format(
+ ", ".join(lib_dll_names), candidate_list
+ )
+ )
+
+ return str(lib_found[0])
def find_source_path() -> str:
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index acdc11e..768463d 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -73,7 +73,7 @@ class Module(core.Object):
The module
"""
- return self.imports_
+ return self.imports_ # type: ignore[return-value]
def implements_function(self, name: str, query_imports: bool = False) ->
bool:
"""Return True if the module defines a global function.
@@ -255,7 +255,7 @@ def system_lib(symbol_prefix: str = "") -> Module:
Parameters
----------
- symbol_prefix: Optional[str]
+ symbol_prefix: str = ""
Optional symbol prefix that can be used for search. When we lookup a
symbol
symbol_prefix + name will first be searched, then the name without
symbol_prefix.
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 6bf08f6..3ef4039 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -16,8 +16,10 @@
# under the License.
"""FFI registry to register function and objects."""
+from __future__ import annotations
+
import sys
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Literal, overload
from . import core
from .core import TypeInfo
@@ -26,7 +28,7 @@ from .core import TypeInfo
_SKIP_UNKNOWN_OBJECTS = False
-def register_object(type_key: str | type | None = None) -> Any:
+def register_object(type_key: str | type | None = None) -> Callable[[type],
type] | type:
"""Register object type.
Parameters
@@ -46,9 +48,8 @@ def register_object(type_key: str | type | None = None) ->
Any:
pass
"""
- object_name = type_key if isinstance(type_key, str) else type_key.__name__
- def register(cls: type) -> type:
+ def _register(cls: type, object_name: str) -> type:
"""Register the object type with the FFI core."""
type_index = core._object_type_key_to_index(object_name)
if type_index is None:
@@ -60,14 +61,25 @@ def register_object(type_key: str | type | None = None) ->
Any:
return cls
if isinstance(type_key, str):
- return register
- return register(type_key)
+ def _decorator_with_name(cls: type) -> type:
+ return _register(cls, type_key)
+
+ return _decorator_with_name
+
+ def _decorator_default(cls: type) -> type:
+ return _register(cls, cls.__name__)
+
+ if type_key is None:
+ return _decorator_default
+ if isinstance(type_key, type):
+ return _decorator_default(type_key)
+ raise TypeError("type_key must be a string, type, or None")
def register_global_func(
func_name: str | Callable[..., Any],
- f: Optional[Callable[..., Any]] = None,
+ f: Callable[..., Any] | None = None,
override: bool = False,
) -> Any:
"""Register global function.
@@ -124,12 +136,20 @@ def register_global_func(
"""Register the global function with the FFI core."""
return core._register_global_func(func_name, myf, override)
- if f:
+ if f is not None:
return register(f)
return register
-def get_global_func(name: str, allow_missing: bool = False) ->
Optional[core.Function]:
+@overload
+def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function
| None: ...
+
+
+@overload
+def get_global_func(name: str, allow_missing: Literal[False] = False) ->
core.Function: ...
+
+
+def get_global_func(name: str, allow_missing: bool = False) -> core.Function |
None:
"""Get a global function by name.
Parameters
@@ -179,7 +199,7 @@ def remove_global_func(name: str) -> None:
get_global_func("ffi.FunctionRemoveGlobal")(name)
-def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) ->
None:
+def init_ffi_api(namespace: str, target_module_name: str | None = None) ->
None:
"""Initialize register ffi api functions into a given module.
Parameters
@@ -225,8 +245,8 @@ def init_ffi_api(namespace: str, target_module_name:
Optional[str] = None) -> No
continue
f = get_global_func(name)
- f.__name__ = fname
- setattr(target_module, f.__name__, f)
+ setattr(f, "__name__", fname)
+ setattr(target_module, fname, f)
def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[...,
Any]:
@@ -253,16 +273,16 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
doc = method.doc if method.doc else None
method_func = method.func
if method.is_static:
- method_pyfunc = staticmethod(method_func)
+ if doc is not None:
+ method_func.__doc__ = doc
+ method_func.__name__ = name
+ method_pyfunc: Any = staticmethod(method_func)
else:
- # must call into another method instead of direct capture
- # to avoid the same method_func variable being used
- # across multiple loop iterations
- method_pyfunc = _member_method_wrapper(method_func)
-
- if doc is not None:
- method_pyfunc.__doc__ = doc
- method_pyfunc.__name__ = name
+ wrapped_func = _member_method_wrapper(method_func)
+ if doc is not None:
+ wrapped_func.__doc__ = doc
+ wrapped_func.__name__ = name
+ method_pyfunc = wrapped_func
if hasattr(type_cls, name):
# skip already defined attributes
diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py
index 2bc0d14..6960eda 100644
--- a/python/tvm_ffi/serialization.py
+++ b/python/tvm_ffi/serialization.py
@@ -16,12 +16,14 @@
# under the License.
"""Serialization related utilities to enable some object can be pickled."""
-from typing import Any, Optional
+from __future__ import annotations
+
+from typing import Any
from . import _ffi_api
-def to_json_graph_str(obj: Any, metadata: Optional[dict] = None) -> str:
+def to_json_graph_str(obj: Any, metadata: dict[str, Any] | None = None) -> str:
"""Dump an object to a JSON graph string.
The JSON graph string is a string representation of of the object
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 3ce9d94..7f2dde5 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -18,7 +18,7 @@
"""Stream context."""
from ctypes import c_void_p
-from typing import Any, NoReturn, Optional, Union
+from typing import Any, Union
from . import core
from ._tensor import device
@@ -72,7 +72,7 @@ try:
class TorchStreamContext:
"""Context manager that syncs Torch and FFI stream contexts."""
- def __init__(self, context: Optional[Any]) -> None:
+ def __init__(self, context: Any) -> None:
"""Initialize with an optional Torch stream/graph context
wrapper."""
self.torch_context = context
@@ -93,14 +93,14 @@ try:
self.torch_context.__exit__(*args)
self.ffi_context.__exit__(*args)
- def use_torch_stream(context: Optional[Any] = None) ->
"TorchStreamContext":
+ def use_torch_stream(context: Any = None) -> "TorchStreamContext":
"""Create an FFI stream context with a Torch stream or graph.
cuda graph or current stream if `None` provided.
Parameters
----------
- context : Optional[Any]
+ context : Any = None
The wrapped torch stream or cuda graph.
Returns
@@ -129,7 +129,7 @@ try:
except ImportError:
- def use_torch_stream(context: Optional[Any] = None) -> NoReturn:
+ def use_torch_stream(context: Any = None) -> "TorchStreamContext":
"""Raise an informative error when Torch is unavailable."""
raise ImportError("Cannot import torch")
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 825f9cf..9769158 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -16,9 +16,12 @@
# under the License.
"""Testing utilities."""
+from __future__ import annotations
+
from typing import Any, ClassVar
from . import _ffi_api
+from .container import Array, Map
from .core import Object
from .dataclasses import c_class, field
from .registry import register_object
@@ -28,11 +31,18 @@ from .registry import register_object
class TestObjectBase(Object):
"""Test object base class."""
+ v_i64: int
+ v_f64: float
+ v_str: str
+
@register_object("testing.TestIntPair")
class TestIntPair(Object):
"""Test Int Pair."""
+ a: int
+ b: int
+
def __init__(self, a: int, b: int) -> None:
"""Construct the object."""
self.__ffi_init__(a, b)
@@ -42,6 +52,9 @@ class TestIntPair(Object):
class TestObjectDerived(TestObjectBase):
"""Test object derived class."""
+ v_map: Map
+ v_array: Array
+
def create_object(type_key: str, **kwargs: Any) -> Object:
"""Make an object by reflection.
@@ -79,7 +92,7 @@ class _TestCxxClassBase:
not_field_2: ClassVar[int] = 2
def __init__(self, v_i64: int, v_i32: int) -> None:
- self.__ffi_init__(v_i64 + 1, v_i32 + 2)
+ self.__ffi_init__(v_i64 + 1, v_i32 + 2) # type: ignore[attr-defined]
@c_class("testing.TestCxxClassDerived")
@@ -90,5 +103,5 @@ class _TestCxxClassDerived(_TestCxxClassBase):
@c_class("testing.TestCxxClassDerivedDerived")
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
- v_str: str = field(default_factory=lambda: "default")
- v_bool: bool
+ v_str: str = field(default_factory=lambda: "default") # type:
ignore[assignment]
+ v_bool: bool # type: ignore[misc]
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index 243a319..6efe80f 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -16,10 +16,12 @@
# under the License.
"""Simple cross-platform advisory file lock utilities."""
+from __future__ import annotations
+
import os
import sys
import time
-from typing import Any, Optional
+from typing import Any, Literal
# Platform-specific imports for file locking
if sys.platform == "win32":
@@ -38,9 +40,9 @@ class FileLock:
def __init__(self, lock_file_path: str) -> None:
"""Initialize a file lock using the given lock file path."""
self.lock_file_path = lock_file_path
- self._file_descriptor = None
+ self._file_descriptor: int | None = None
- def __enter__(self) -> "FileLock":
+ def __enter__(self) -> FileLock:
"""Acquire the lock upon entering the context.
This method blocks until the lock is acquired.
@@ -48,12 +50,12 @@ class FileLock:
self.blocking_acquire()
return self
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) ->
Literal[False]:
"""Context manager protocol: release the lock upon exiting the 'with'
block."""
self.release()
return False # Propagate exceptions, if any
- def acquire(self) -> Optional[bool]:
+ def acquire(self) -> bool:
"""Acquire an exclusive, non-blocking lock on the file.
Returns True if the lock was acquired, False otherwise.
@@ -79,9 +81,7 @@ class FileLock:
self._file_descriptor = None
raise RuntimeError(f"An unexpected error occurred: {e}")
- def blocking_acquire(
- self, timeout: Optional[float] = None, poll_interval: float = 0.1
- ) -> Optional[bool]:
+ def blocking_acquire(self, timeout: float | None = None, poll_interval:
float = 0.1) -> bool:
"""Wait until an exclusive lock can be acquired, with an optional
timeout.
Args:
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index da30464..713520f 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -170,7 +170,7 @@ FMT_MAP = {
}
# Files and patterns to skip during header checking
-SKIP_LIST = []
+SKIP_LIST: list[str] = []
def should_skip_file(filepath: str) -> bool:
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 9d08f9a..c776b20 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -183,8 +183,8 @@ def main() -> None:
cmd = ["git", "ls-files"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- assert proc.returncode == 0, f"{' '.join(cmd)} errored: {out}"
res = out.decode("utf-8")
+ assert proc.returncode == 0, f"{' '.join(cmd)} errored: {res}"
flist = res.split()
error_list = []
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index f5f28f3..f42e9b3 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -38,10 +38,10 @@ def test_bad_constructor_init_state() -> None:
proper repr code
"""
with pytest.raises(TypeError):
- tvm_ffi.Array(1)
+ tvm_ffi.Array(1) # type: ignore[arg-type]
with pytest.raises(AttributeError):
- tvm_ffi.Map(1)
+ tvm_ffi.Map(1) # type: ignore[arg-type]
def test_array_of_array_map() -> None:
diff --git a/tests/python/test_dataclasses_c_class.py
b/tests/python/test_dataclasses_c_class.py
index a2fa80e..0050e6c 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -57,7 +57,7 @@ def test_cxx_class_derived_derived() -> None:
def test_cxx_class_derived_derived_default() -> None:
- obj = _TestCxxClassDerivedDerived(123, 456, 4, True)
+ obj = _TestCxxClassDerivedDerived(123, 456, 4, True) # type:
ignore[call-arg,misc]
assert obj.v_i64 == 123
assert obj.v_i32 == 456
assert isinstance(obj.v_f64, float) and obj.v_f64 == 4
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 9441c9f..7a3638b 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import pickle
import pytest
@@ -87,7 +89,7 @@ def test_deive_type_error(dev_type: str, dev_id: int | None)
-> None:
def test_deive_id_error() -> None:
with pytest.raises(TypeError):
- tvm_ffi.device("cpu", "?")
+ tvm_ffi.device("cpu", "?") # type: ignore[arg-type]
def test_device_pickle() -> None:
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index dda436c..dd94cf3 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -39,9 +39,9 @@ def test_error_from_cxx() -> None:
try:
test_raise_error("ValueError", "error XYZ")
except ValueError as e:
- assert e.__tvm_ffi_error__.kind == "ValueError"
- assert e.__tvm_ffi_error__.message == "error XYZ"
- assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1
+ assert e.__tvm_ffi_error__.kind == "ValueError" # type:
ignore[attr-defined]
+ assert e.__tvm_ffi_error__.message == "error XYZ" # type:
ignore[attr-defined]
+ assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 #
type: ignore[attr-defined]
fapply = tvm_ffi.convert(lambda f, *args: f(*args))
@@ -64,17 +64,17 @@ def test_error_from_nested_pyfunc() -> None:
try:
fapply(cxx_test_raise_error, "ValueError", "error XYZ")
except ValueError as e:
- assert e.__tvm_ffi_error__.kind == "ValueError"
- assert e.__tvm_ffi_error__.message == "error XYZ"
- assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1
- record_object.append(e.__tvm_ffi_error__)
+ assert e.__tvm_ffi_error__.kind == "ValueError" # type:
ignore[attr-defined]
+ assert e.__tvm_ffi_error__.message == "error XYZ" # type:
ignore[attr-defined]
+ assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1
# type: ignore[attr-defined]
+ record_object.append(e.__tvm_ffi_error__) # type:
ignore[attr-defined]
raise e
try:
cxx_test_apply(raise_error)
except ValueError as e:
- backtrace = e.__tvm_ffi_error__.backtrace
- assert e.__tvm_ffi_error__.same_as(record_object[0])
+ backtrace = e.__tvm_ffi_error__.backtrace # type: ignore[attr-defined]
+ assert e.__tvm_ffi_error__.same_as(record_object[0]) # type:
ignore[attr-defined]
assert backtrace.count("TestRaiseError") == 1
# The following lines may fail if debug symbols are missing
try:
@@ -108,7 +108,7 @@ def test_error_traceback_update() -> None:
try:
raise_cxx_error()
except ValueError as e:
- assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1
+ assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1 #
type: ignore[attr-defined]
ffi_error1 = tvm_ffi.convert(e)
ffi_error2 = fecho(e)
assert ffi_error1.backtrace.find("raise_cxx_error") != -1
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index d0afd5f..edc9ffd 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -113,11 +113,11 @@ def test_string_bytes_passing() -> None:
# small bytes
assert fecho(b"hello") == b"hello"
# large bytes
- x = b"hello" * 100
- y = fecho(x)
- assert y == x
- assert y.__tvm_ffi_object__ is not None
- fecho(y) == 1
+ x2 = b"hello" * 100
+ y2 = fecho(x2)
+ assert y2 == x2
+ assert y2.__tvm_ffi_object__ is not None
+ fecho(y2) == 1
def test_nested_container_passing() -> None:
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 5454284..cd46bf5 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -15,12 +15,16 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from types import ModuleType
import numpy
import pytest
+torch: ModuleType | None
try:
- import torch
+ import torch # type: ignore[no-redef]
except ImportError:
torch = None
@@ -197,6 +201,7 @@ def test_load_inline_cuda() -> None:
@pytest.mark.skipif(torch is None, reason="Requires torch")
def test_load_inline_with_env_tensor_allocator() -> None:
+ assert torch is not None
if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
mod: Module = tvm_ffi.cpp.load_inline(
@@ -241,6 +246,7 @@ def test_load_inline_with_env_tensor_allocator() -> None:
torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
)
def test_load_inline_both() -> None:
+ assert torch is not None
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index ea54adf..aa1a791 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -24,6 +24,7 @@ import tvm_ffi
def test_make_object() -> None:
# with default values
obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase")
+ assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
assert obj0.v_i64 == 10
assert obj0.v_f64 == 10.0
assert obj0.v_str == "hello"
@@ -37,14 +38,16 @@ def test_make_object_via_init() -> None:
def test_method() -> None:
obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
- assert obj0.add_i64(1) == 13
- assert type(obj0).add_i64.__doc__ == "add_i64 method"
- assert type(obj0).v_i64.__doc__ == "i64 field"
+ assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
+ assert obj0.add_i64(1) == 13 # type: ignore[attr-defined]
+ assert type(obj0).add_i64.__doc__ == "add_i64 method" # type:
ignore[attr-defined]
+ assert type(obj0).v_i64.__doc__ == "i64 field" # type:
ignore[attr-defined]
def test_setter() -> None:
# test setter
obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10,
v_str="hello")
+ assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
assert obj0.v_i64 == 10
obj0.v_i64 = 11
assert obj0.v_i64 == 11
@@ -52,10 +55,10 @@ def test_setter() -> None:
assert obj0.v_str == "world"
with pytest.raises(TypeError):
- obj0.v_str = 1
+ obj0.v_str = 1 # type: ignore[assignment]
with pytest.raises(TypeError):
- obj0.v_i64 = "hello"
+ obj0.v_i64 = "hello" # type: ignore[assignment]
def test_derived_object() -> None:
@@ -68,6 +71,7 @@ def test_derived_object() -> None:
obj0 = tvm_ffi.testing.create_object(
"testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array
)
+ assert isinstance(obj0, tvm_ffi.testing.TestObjectDerived)
assert obj0.v_map.same_as(v_map)
assert obj0.v_array.same_as(v_array)
assert obj0.v_i64 == 20
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index cfaf650..9280aab 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -15,12 +15,17 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from types import ModuleType
+
import pytest
import tvm_ffi
import tvm_ffi.cpp
+torch: ModuleType | None
try:
- import torch
+ import torch # type: ignore[no-redef]
except ImportError:
torch = None
@@ -56,6 +61,7 @@ def test_raw_stream() -> None:
torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
)
def test_torch_stream() -> None:
+ assert torch is not None
mod = gen_check_stream_mod()
device_id = torch.cuda.current_device()
device = tvm_ffi.device("cuda", device_id)
@@ -78,6 +84,7 @@ def test_torch_stream() -> None:
torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
)
def test_torch_current_stream() -> None:
+ assert torch is not None
mod = gen_check_stream_mod()
device_id = torch.cuda.current_device()
device = tvm_ffi.device("cuda", device_id)
@@ -103,6 +110,7 @@ def test_torch_current_stream() -> None:
torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
)
def test_torch_graph() -> None:
+ assert torch is not None
mod = gen_check_stream_mod()
device_id = torch.cuda.current_device()
device = tvm_ffi.device("cuda", device_id)
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 4c2e9a8..186d91b 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -14,10 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from __future__ import annotations
+
+from types import ModuleType
+
import pytest
+torch: ModuleType | None
try:
- import torch
+ import torch # type: ignore[no-redef]
except ImportError:
torch = None
@@ -45,18 +51,19 @@ def test_shape_object() -> None:
assert shape == (10, 8, 4, 2)
fecho = tvm_ffi.convert(lambda x: x)
- shape2 = fecho(shape)
+ shape2: tvm_ffi.Shape = fecho(shape)
assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert isinstance(shape2, tvm_ffi.Shape)
assert isinstance(shape2, tuple)
- shape3 = tvm_ffi.convert(shape)
+ shape3: tvm_ffi.Shape = tvm_ffi.convert(shape)
assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
assert isinstance(shape3, tvm_ffi.Shape)
@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not
enabled")
def test_tensor_auto_dlpack() -> None:
+ assert torch is not None
x = torch.arange(128)
fecho = tvm_ffi.get_global_func("testing.echo")
y = fecho(x)
diff --git a/tests/scripts/benchmark_dlpack.py
b/tests/scripts/benchmark_dlpack.py
index 2d9b296..4366b58 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -254,9 +254,9 @@ def
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
- x = tvm_ffi.core.DLTensorTestWrapper(x)
- y = tvm_ffi.core.DLTensorTestWrapper(y)
- z = tvm_ffi.core.DLTensorTestWrapper(z)
+ x = tvm_ffi.core.DLTensorTestWrapper(x) # type: ignore[assignment]
+ y = tvm_ffi.core.DLTensorTestWrapper(y) # type: ignore[assignment]
+ z = tvm_ffi.core.DLTensorTestWrapper(z) # type: ignore[assignment]
bench_tvm_ffi_nop_autodlpack(
f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z,
repeat
)