This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7619669 chore: Switch from mypy to ty (#432)
7619669 is described below
commit 761966953fec7e8ceba0fcc20dc003e39cf2692d
Author: Junru Shao <[email protected]>
AuthorDate: Sat Feb 7 04:12:46 2026 -0800
chore: Switch from mypy to ty (#432)
Astral's `uv`, `ruff`, and `ty` are all extremely solid Python tooling
infrastructure. In particular, this PR migrates from `mypy` to `ty`, a type
checker + LSP that is significantly faster and more accurate.
---
.github/workflows/ci_test.yml | 4 ++
.pre-commit-config.yaml | 16 +++----
CONTRIBUTING.md | 2 +-
examples/cubin_launcher/benchmark_overhead.py | 4 +-
examples/cubin_launcher/example_triton_cubin.py | 4 +-
pyproject.toml | 51 ++++++++++++----------
python/tvm_ffi/__init__.py | 4 +-
python/tvm_ffi/_convert.py | 5 +--
python/tvm_ffi/_optional_torch_c_dlpack.py | 4 +-
python/tvm_ffi/_tensor.py | 4 +-
python/tvm_ffi/container.py | 6 +--
python/tvm_ffi/cpp/nvrtc.py | 2 +-
python/tvm_ffi/dataclasses/_utils.py | 6 +--
python/tvm_ffi/dataclasses/field.py | 8 ++--
python/tvm_ffi/error.py | 6 +--
python/tvm_ffi/libinfo.py | 2 +-
python/tvm_ffi/module.py | 2 +-
python/tvm_ffi/registry.py | 4 +-
python/tvm_ffi/stub/cli.py | 6 +--
python/tvm_ffi/stub/file_utils.py | 4 +-
python/tvm_ffi/testing/testing.py | 4 +-
.../utils/_build_optional_torch_c_dlpack.py | 2 +-
tests/python/test_container.py | 6 +--
tests/python/test_cubin_launcher.py | 4 +-
tests/python/test_dataclasses_c_class.py | 6 +--
tests/python/test_device.py | 2 +-
tests/python/test_dlpack_exchange_api.py | 9 ++--
tests/python/test_error.py | 20 ++++-----
tests/python/test_function.py | 2 +-
tests/python/test_load_inline.py | 7 ++-
tests/python/test_object.py | 32 +++++++-------
tests/python/test_optional_torch_c_dlpack.py | 3 +-
tests/python/test_stream.py | 6 +--
tests/python/test_tensor.py | 4 +-
tests/python/utils/test_embed_cubin.py | 4 +-
tests/python/utils/test_kwargs_wrapper.py | 8 ++--
tests/scripts/benchmark_dlpack.py | 6 +--
tests/scripts/benchmark_kwargs_wrapper.py | 2 +-
38 files changed, 134 insertions(+), 137 deletions(-)
diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml
index 88e7157..440fa40 100644
--- a/.github/workflows/ci_test.yml
+++ b/.github/workflows/ci_test.yml
@@ -63,6 +63,10 @@ jobs:
with:
fetch-depth: 0
fetch-tags: true
+ - name: Set up uv
+ uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
+ - name: Set up Python environment
+ run: uv sync --group dev --no-install-project
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
clang-tidy:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 99fa44e..d0e7aa8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -113,18 +113,14 @@ repos:
args:
- --config
- docs/.rstcheck.cfg
- - repo: https://github.com/pre-commit/mirrors-mypy
- rev: "v1.19.0"
+ - repo: local
hooks:
- - id: mypy
- name: mypy
+ - id: ty
+ name: ty check
+ entry: uvx ty check
+ language: system
pass_filenames: false
- args: []
- additional_dependencies:
- - numpy>=1.22
- - ml-dtypes>=0.1
- - pytest
- - typing-extensions>=4.5
+ types: [python]
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 12d5edc..837df87 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ The pre-commit configuration includes checks for:
- **License headers**: Ensures all files have proper Apache Software Foundation headers
- **Code formatting**: Runs clang-format (C++), ruff (Python), shfmt (Shell scripts)
- **Linting**: Runs clang-tidy, ruff, shellcheck, markdownlint, yamllint, and more
-- **Type checking**: Runs mypy for Python type annotations
+- **Type checking**: Runs ty for Python type annotations
- **File quality**: Checks for trailing whitespace, file sizes, merge conflicts, etc.
### Troubleshooting
diff --git a/examples/cubin_launcher/benchmark_overhead.py b/examples/cubin_launcher/benchmark_overhead.py
index 89bfa84..efc82b2 100644
--- a/examples/cubin_launcher/benchmark_overhead.py
+++ b/examples/cubin_launcher/benchmark_overhead.py
@@ -39,8 +39,8 @@ import traceback
from typing import Callable
import torch
-import triton # type: ignore[import-not-found]
-import triton.language as tl # type: ignore[import-not-found]
+import triton
+import triton.language as tl
from tvm_ffi import cpp
from tvm_ffi.module import Module
diff --git a/examples/cubin_launcher/example_triton_cubin.py b/examples/cubin_launcher/example_triton_cubin.py
index 6654289..b127cd2 100644
--- a/examples/cubin_launcher/example_triton_cubin.py
+++ b/examples/cubin_launcher/example_triton_cubin.py
@@ -35,8 +35,8 @@ import sys
import traceback
import torch
-import triton # type: ignore[import-not-found]
-import triton.language as tl # type: ignore[import-not-found]
+import triton
+import triton.language as tl
from tvm_ffi import cpp
diff --git a/pyproject.toml b/pyproject.toml
index 13ef672..0a096f8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,7 @@ test = [
[dependency-groups]
dev = [
"ruff",
- "mypy",
+ "ty",
"clang-format",
"clang-tidy",
"ipdb",
@@ -255,28 +255,33 @@ environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" }
[tool.cibuildwheel.windows]
archs = ["AMD64"]
-[tool.mypy]
-python_version = "3.9"
-show_error_codes = true
-mypy_path = ["python", "examples", "tests/python"]
-files = ["python/tvm_ffi", "examples", "tests/python"]
-namespace_packages = true
-explicit_package_bases = true
-allow_redefinition = true
-exclude = '''(?x)(
- ^\.venv/|
- ^build/|
- ^dist/|
- ^\.mypy_cache/
-)'''
-
-[[tool.mypy.overrides]]
-module = ["torch", "torch.*", "my_ffi_extension", "my_ffi_extension.*"]
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "_pytest.*"
-follow_imports = "skip"
+[tool.ty.environment]
+python-version = "3.9"
+python = ".venv"
+extra-paths = ["python", "examples", "tests/python"]
+
+[tool.ty.src]
+include = ["python/tvm_ffi/**", "examples/**", "tests/python/**"]
+exclude = [".venv/**", "build/**", "dist/**"]
+
+[tool.ty.analysis]
+allowed-unresolved-imports = [
+ "torch",
+ "torch.*",
+ "torch.utils.*",
+ "my_ffi_extension",
+ "my_ffi_extension.*",
+ "_pytest.*",
+ "cupy",
+ "cupy.*",
+ "paddle",
+ "paddle.*",
+ "triton",
+ "triton.*",
+ "cuda.bindings",
+ "cuda.bindings.*",
+ "torch_c_dlpack_ext",
+]
[tool.uv.dependency-groups]
docs = { requires-python = ">=3.13" }
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 8d68b5d..87e6ef2 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -26,7 +26,7 @@
# 2. Python 3.12
# 3. torch 2.9.0
try:
- import torch # type: ignore
+ import torch
except ImportError:
pass
@@ -86,7 +86,7 @@ from ._dtype import (
)
try:
- from ._version import __version__, __version_tuple__ # type: ignore[import-not-found]
+ from ._version import __version__, __version_tuple__
except ImportError:
__version__ = "0.0.0.dev0"
__version_tuple__ = (0, 0, 0, "dev0", "7d34eb8ab.d20250913")
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index d35ad90..27f4e98 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -25,11 +25,10 @@ from typing import Any
from . import _dtype, container, core
-torch: ModuleType | None = None
try:
- import torch # type: ignore[no-redef]
+ import torch
except ImportError:
- pass
+ torch = None
numpy: ModuleType | None = None
try:
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index dde0550..415e151 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -41,7 +41,7 @@ import warnings
from pathlib import Path
from typing import Any
-logger = logging.getLogger(__name__) # type: ignore
+logger = logging.getLogger(__name__)
def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
@@ -103,7 +103,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
"""Load the torch c dlpack extension."""
try:
- import torch_c_dlpack_ext # type: ignore # noqa: PLC0415, F401
+ import torch_c_dlpack_ext # noqa: PLC0415, F401
if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
return None
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index b9a4954..f81132e 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -73,8 +73,8 @@ class Shape(tuple, PyNativeObject):
def __from_tvm_ffi_object__(cls, obj: Any) -> Shape:
"""Construct from a given tvm object."""
content = _shape_obj_get_py_tuple(obj)
- val: Shape = tuple.__new__(cls, content) # type: ignore[arg-type]
- val._tvm_ffi_cached_object = obj # type: ignore[attr-defined]
+ val: Shape = tuple.__new__(cls, content) # ty: ignore[invalid-argument-type]
+ val._tvm_ffi_cached_object = obj
return val
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 3f6884f..25cdfba 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -169,7 +169,7 @@ class Array(core.Object, Sequence[T]):
@overload
def __getitem__(self, idx: slice, /) -> list[T]: ...
- def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]:
+ def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: # ty: ignore[invalid-method-override]
"""Return one element or a list for a slice."""
length = len(self)
result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx)
@@ -227,7 +227,7 @@ class KeysView(KeysViewBase[K]):
if not functor(2):
break
- def __contains__(self, k: object) -> bool:
+ def __contains__(self, k: object) -> bool: # ty: ignore[invalid-method-override]
return k in self._backend_map
@@ -273,7 +273,7 @@ class ItemsView(ItemsViewBase[K, V]):
if not isinstance(item, tuple) or len(item) != 2:
return False
key, value = item
- actual_value = self._backend_map.get(key, MISSING)
+ actual_value = self._backend_map.get(key, MISSING) # ty: ignore[invalid-argument-type]
if actual_value is MISSING:
return False
# TODO(@junrus): Is `__eq__` the right method to use here?
diff --git a/python/tvm_ffi/cpp/nvrtc.py b/python/tvm_ffi/cpp/nvrtc.py
index 97033bd..73f76d7 100644
--- a/python/tvm_ffi/cpp/nvrtc.py
+++ b/python/tvm_ffi/cpp/nvrtc.py
@@ -79,7 +79,7 @@ def nvrtc_compile( # noqa: PLR0912, PLR0915
"""
try:
- from cuda.bindings import driver, nvrtc # type: ignore[import-not-found] # noqa: PLC0415
+ from cuda.bindings import driver, nvrtc # noqa: PLC0415
except ImportError as e:
raise RuntimeError(
"CUDA bindings not available. Install with: pip install
cuda-python"
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index 7c0afb4..80e7010 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -61,8 +61,8 @@ def type_info_to_cls(
# Allow overriding methods (including from base classes like Object.__repr__)
# by always adding to attrs, which will be used when creating the new class
func.__module__ = cls.__module__
- func.__name__ = name
- func.__qualname__ = f"{cls.__qualname__}.{name}"
+ func.__name__ = name # ty: ignore[unresolved-attribute]
+ func.__qualname__ = f"{cls.__qualname__}.{name}" # ty: ignore[unresolved-attribute]
func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
attrs[name] = func
@@ -75,7 +75,7 @@ def type_info_to_cls(
# Step 4. Create the new class
new_cls = type(cls.__name__, cls_bases, attrs)
new_cls.__module__ = cls.__module__
- new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore
+ new_cls = functools.wraps(cls, updated=())(new_cls)
return cast(Type[_InputClsType], new_cls)
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index a395e50..a03642f 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -22,7 +22,7 @@ from dataclasses import _MISSING_TYPE, MISSING
from typing import Any, Callable, TypeVar, cast
try:
- from dataclasses import KW_ONLY # type: ignore[attr-defined]
+ from dataclasses import KW_ONLY # ty: ignore[unresolved-import]
except ImportError:
# Python < 3.10: define our own KW_ONLY sentinel
class _KW_ONLY_Sentinel:
@@ -68,11 +68,11 @@ class Field:
def field(
*,
- default: _FieldValue | _MISSING_TYPE = MISSING, # type: ignore[assignment]
- default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment]
+ default: _FieldValue | _MISSING_TYPE = MISSING,
+ default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
repr: bool = True,
- kw_only: bool | _MISSING_TYPE = MISSING, # type: ignore[assignment]
+ kw_only: bool | _MISSING_TYPE = MISSING,
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index 3c4eaf0..2842629 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -76,9 +76,9 @@ class TracebackManager:
tree = ast.parse("_getframe()", filename=filename, mode="eval")
for node in ast.walk(tree):
if hasattr(node, "col_offset"):
- node.col_offset = 0
+ node.col_offset = 0 # ty: ignore[invalid-assignment]
if hasattr(node, "end_col_offset"):
- node.end_col_offset = 0
+ node.end_col_offset = 0 # ty: ignore[invalid-assignment]
# call into get frame, bt changes the context
code_object = compile(tree, filename, "eval")
# replace the function name and line number
@@ -240,7 +240,7 @@ def register_error(
assert isinstance(py_err, MyError)
"""
- if callable(name_or_cls):
+ if isinstance(name_or_cls, type):
cls = name_or_cls
name_or_cls = cls.__name__
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index d1c1bcf..8649a81 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -232,7 +232,7 @@ def _find_library_by_basename(package: str, target_name: str) -> Path: # noqa:
lib_dll_names = (f"lib{target_name}.so",)
# Use `importlib.metadata` is the most reliable way to find package data files
- dist: im.PathDistribution = im.distribution(package) # type: ignore[assignment]
+ dist: im.PathDistribution = im.distribution(package) # ty: ignore[invalid-assignment]
record = dist.read_text("RECORD") or ""
for line in record.splitlines():
partial_path, *_ = line.split(",")
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 5aac8f0..ad7dfce 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -129,7 +129,7 @@ class Module(core.Object):
The module
"""
- return self.imports_ # type: ignore[return-value]
+ return self.imports_ # ty: ignore[invalid-return-type]
def implements_function(self, name: str, query_imports: bool = False) -> bool:
"""Return True if the module defines a global function.
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 9ea8e2c..8126a66 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -134,9 +134,9 @@ def register_global_func(
:py:func:`tvm_ffi.remove_global_func`
"""
- if callable(func_name):
+ if not isinstance(func_name, str):
f = func_name
- func_name = f.__name__
+ func_name = f.__name__ # ty: ignore[unresolved-attribute]
if not isinstance(func_name, str):
raise ValueError("expect string function name")
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
index 1540cbc..0cfd839 100644
--- a/python/tvm_ffi/stub/cli.py
+++ b/python/tvm_ffi/stub/cli.py
@@ -133,10 +133,10 @@ def _stage_2(
return ret
# Step 0. Find out functions and classes already defined on files.
- defined_func_prefixes: set[str] = { # type: ignore[union-attr]
+ defined_func_prefixes: set[str] = {
code.param[0] for file in files for code in file.code_blocks if code.kind == "global"
}
- defined_objs: set[str] = { # type: ignore[assignment]
+ defined_objs: set[str] = { # ty: ignore[invalid-assignment]
code.param for file in files for code in file.code_blocks if code.kind == "object"
} | C.BUILTIN_TYPE_KEYS
@@ -198,7 +198,7 @@ def _stage_3( # noqa: PLR0912
# Stage 1. Collect `tvm-ffi-stubgen(import-object): ...`
for code in file.code_blocks:
if code.kind == "import-object":
- name, type_checking_only, alias = code.param # type: ignore[misc]
+ name, type_checking_only, alias = code.param
imports.append(
ImportItem(
name,
diff --git a/python/tvm_ffi/stub/file_utils.py b/python/tvm_ffi/stub/file_utils.py
index 055c56d..e7ce609 100644
--- a/python/tvm_ffi/stub/file_utils.py
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -111,7 +111,7 @@ class CodeBlock:
else:
raise ValueError(f"Unknown stub type `{stub}` at line {lineo}")
return CodeBlock(
- kind=kind, # type: ignore[arg-type]
+ kind=kind,
param=param,
lineno_start=lineo,
lineno_end=None,
@@ -275,7 +275,7 @@ def path_walk(
return
# Python 3.12+ - just delegate to `Path.walk`
if hasattr(p, "walk"):
- yield from p.walk( # type: ignore[attr-defined]
+ yield from p.walk( # ty: ignore[call-non-callable]
top_down=top_down,
on_error=on_error,
follow_symlinks=follow_symlinks,
diff --git a/python/tvm_ffi/testing/testing.py b/python/tvm_ffi/testing/testing.py
index 0ffeb49..cb6a0e1 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -158,7 +158,7 @@ class _TestCxxClassBase:
not_field_2: ClassVar[int] = 2
def __init__(self, v_i64: int, v_i32: int) -> None:
- self.__ffi_init__(v_i64 + 1, v_i32 + 2) # type: ignore[attr-defined]
+ self.__ffi_init__(v_i64 + 1, v_i32 + 2) # ty: ignore[unresolved-attribute]
@c_class("testing.TestCxxClassDerived")
@@ -170,7 +170,7 @@ class _TestCxxClassDerived(_TestCxxClassBase):
@c_class("testing.TestCxxClassDerivedDerived")
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
v_str: str = field(default_factory=lambda: "default")
- v_bool: bool # type: ignore[misc] # Suppress: Attributes without a default cannot follow attributes with one
+ v_bool: bool # ty: ignore[dataclass-field-order] # Required field after fields with defaults
@c_class("testing.TestCxxInitSubset")
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index fd0aff4..985f191 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -716,7 +716,7 @@ def get_torch_include_paths(build_with_cuda: bool) -> Sequence[str]:
device_type="cuda" if build_with_cuda else "cpu"
)
else:
- return torch.utils.cpp_extension.include_paths(cuda=build_with_cuda) # type: ignore[call-arg]
+ return torch.utils.cpp_extension.include_paths(cuda=build_with_cuda)
def main() -> None: # noqa: PLR0912, PLR0915
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index 37f7432..ffedb75 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -50,10 +50,10 @@ def test_bad_constructor_init_state() -> None:
proper repr code
"""
with pytest.raises(TypeError):
- tvm_ffi.Array(1) # type: ignore[arg-type]
+ tvm_ffi.Array(1) # ty: ignore[invalid-argument-type]
with pytest.raises(AttributeError):
- tvm_ffi.Map(1) # type: ignore[arg-type]
+ tvm_ffi.Map(1) # ty: ignore[invalid-argument-type]
def test_array_of_array_map() -> None:
@@ -198,7 +198,7 @@ def test_array_concat(
b: Sequence[int],
c_expected: Sequence[int],
) -> None:
- c_actual = a + b # type: ignore[operator]
+ c_actual = a + b # ty: ignore[unsupported-operator]
assert type(c_actual) is type(c_expected)
assert len(c_actual) == len(c_expected)
assert tuple(c_actual) == tuple(c_expected)
diff --git a/tests/python/test_cubin_launcher.py b/tests/python/test_cubin_launcher.py
index 014511a..6190670 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -22,13 +22,11 @@ import subprocess
import sys
import tempfile
from pathlib import Path
-from types import ModuleType
import pytest
-torch: ModuleType | None
try:
- import torch # type: ignore[import-not-found,no-redef]
+ import torch
except ImportError:
torch = None
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index 3a757d0..e3735e4 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -69,7 +69,7 @@ def test_cxx_class_derived_derived() -> None:
def test_cxx_class_derived_derived_default() -> None:
- obj = _TestCxxClassDerivedDerived(123, 456, 4, True) # type: ignore[call-arg,misc]
+ obj = _TestCxxClassDerivedDerived(123, 456, 4, True) # ty: ignore[missing-argument]
assert obj.v_i64 == 123
assert obj.v_i32 == 456
assert isinstance(obj.v_f64, float) and obj.v_f64 == 4
@@ -94,7 +94,7 @@ def test_cxx_class_init_subset_defaults() -> None:
def test_cxx_class_init_subset_positional() -> None:
- obj = _TestCxxInitSubset(7) # type: ignore[call-arg]
+ obj = _TestCxxInitSubset(7)
assert obj.required_field == 7
assert obj.optional_field == -1
obj.optional_field = 11
@@ -160,7 +160,7 @@ def test_kw_only_class_level_with_default() -> None:
def test_kw_only_class_level_rejects_positional() -> None:
with pytest.raises(TypeError, match="positional"):
- _TestCxxKwOnly(1, 2, 3, 4) # type: ignore[misc]
+ _TestCxxKwOnly(1, 2, 3, 4) # ty: ignore[missing-argument, too-many-positional-arguments]
def test_field_kw_only_parameter() -> None:
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 33b48e8..6575ec9 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -95,7 +95,7 @@ def test_deive_type_error(dev_type: str, dev_id: int | None) -> None:
def test_deive_id_error() -> None:
with pytest.raises(TypeError):
- tvm_ffi.device("cpu", "?") # type: ignore[arg-type]
+ tvm_ffi.device("cpu", "?") # ty: ignore[invalid-argument-type]
def test_device_pickle() -> None:
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 9d1df21..dda1378 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -24,15 +24,15 @@ import sys
import pytest
try:
- import torch # type: ignore[no-redef]
+ import torch
# Import tvm_ffi to load the DLPack exchange API extension
# This sets torch.Tensor.__dlpack_c_exchange_api__
import tvm_ffi
- from torch.utils import cpp_extension # type: ignore
+ from torch.utils import cpp_extension
from tvm_ffi import libinfo
except ImportError:
- torch = None # type: ignore[assignment]
+ torch = None
# Check if DLPack Exchange API is available
_has_dlpack_api = torch is not None and hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
@@ -46,7 +46,7 @@ def test_dlpack_exchange_api() -> None:
assert torch is not None
assert hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
- api_attr = torch.Tensor.__dlpack_c_exchange_api__ # type: ignore[attr-defined]
+ api_attr = torch.Tensor.__dlpack_c_exchange_api__
# PyCapsule - extract the pointer as integer
pythonapi = ctypes.pythonapi
# Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
@@ -217,6 +217,7 @@ def test_dlpack_exchange_api() -> None:
@pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API not available")
def test_from_dlpack_torch() -> None:
# Covers from_dlpack to use fallback fastpath
+ assert torch is not None
tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
tensor_from_dlpack = tvm_ffi.from_dlpack(tensor)
assert tensor_from_dlpack.shape == tensor.shape
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index 068fb3f..c6c36f9 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -41,9 +41,9 @@ def test_error_from_cxx() -> None:
try:
test_raise_error("ValueError", "error XYZ")
except ValueError as e:
- assert e.__tvm_ffi_error__.kind == "ValueError" # type: ignore[attr-defined]
- assert e.__tvm_ffi_error__.message == "error XYZ" # type: ignore[attr-defined]
- assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # type: ignore[attr-defined]
+ assert e.__tvm_ffi_error__.kind == "ValueError" # ty: ignore[unresolved-attribute]
+ assert e.__tvm_ffi_error__.message == "error XYZ" # ty: ignore[unresolved-attribute]
+ assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # ty: ignore[unresolved-attribute]
fapply = tvm_ffi.convert(lambda f, *args: f(*args))
@@ -66,17 +66,17 @@ def test_error_from_nested_pyfunc() -> None:
try:
fapply(cxx_test_raise_error, "ValueError", "error XYZ")
except ValueError as e:
- assert e.__tvm_ffi_error__.kind == "ValueError" # type: ignore[attr-defined]
- assert e.__tvm_ffi_error__.message == "error XYZ" # type: ignore[attr-defined]
- assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # type: ignore[attr-defined]
- record_object.append(e.__tvm_ffi_error__) # type: ignore[attr-defined]
+ assert e.__tvm_ffi_error__.kind == "ValueError" # ty: ignore[unresolved-attribute]
+ assert e.__tvm_ffi_error__.message == "error XYZ" # ty: ignore[unresolved-attribute]
+ assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # ty: ignore[unresolved-attribute]
+ record_object.append(e.__tvm_ffi_error__) # ty: ignore[unresolved-attribute]
raise e
try:
cxx_test_apply(raise_error)
except ValueError as e:
- backtrace = e.__tvm_ffi_error__.backtrace # type: ignore[attr-defined]
- assert e.__tvm_ffi_error__.same_as(record_object[0]) # type: ignore[attr-defined]
+ backtrace = e.__tvm_ffi_error__.backtrace # ty: ignore[unresolved-attribute]
+ assert e.__tvm_ffi_error__.same_as(record_object[0]) # ty: ignore[unresolved-attribute]
assert backtrace.count("TestRaiseError") == 1
# The following lines may fail if debug symbols are missing
try:
@@ -110,7 +110,7 @@ def test_error_traceback_update() -> None:
try:
raise_cxx_error()
except ValueError as e:
- assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1 # type: ignore[attr-defined]
+ assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1 # ty: ignore[unresolved-attribute]
ffi_error1 = tvm_ffi.convert(e)
ffi_error2 = fecho(e)
assert ffi_error1.backtrace.find("raise_cxx_error") != -1
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 8a494fb..686fc08 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -284,7 +284,7 @@ def test_function_subclass() -> None:
# When subclassing a Cython cdef class and overriding `__init__`,
# special methods like `__call__` may not be inherited automatically.
# This explicit assignment ensures the subclass remains callable.
- __call__ = tvm_ffi.Function.__call__ # type: ignore
+ __call__ = tvm_ffi.Function.__call__
f = tvm_ffi.convert(lambda x: x)
assert isinstance(f, tvm_ffi.Function)
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 77b9f8b..fb4b582 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -17,14 +17,11 @@
from __future__ import annotations
-from types import ModuleType
-
import numpy
import pytest
-torch: ModuleType | None
try:
- import torch # type: ignore[no-redef]
+ import torch
except ImportError:
torch = None
@@ -251,6 +248,7 @@ def test_load_inline_with_env_tensor_allocator() -> None:
When a module returns an object, the object deleter address is part of the
loaded library. We need to keep the module loaded until the object is deleted.
"""
+ assert torch is not None
x_cpu = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cpu")
# test support for nested container passing
y_cpu = mod.return_add_one({"x": [x_cpu]})
@@ -353,6 +351,7 @@ def test_cuda_memory_alloc_noleak() -> None:
def run_check() -> None:
"""Must run in a separate function to ensure deletion happens before
mod unloads."""
+ assert torch is not None
x = torch.arange(1024 * 1024, dtype=torch.float32, device="cuda")
current_allocated = torch.cuda.memory_allocated()
repeat = 8
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index e49d3ec..3f1554c 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -35,7 +35,7 @@ def test_make_object() -> None:
def test_make_object_via_init() -> None:
- obj0 = tvm_ffi.testing.TestIntPair(1, 2) # type: ignore[call-arg]
+ obj0 = tvm_ffi.testing.TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
assert obj0.a == 1
assert obj0.b == 2
@@ -43,13 +43,13 @@ def test_make_object_via_init() -> None:
def test_method() -> None:
obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
- assert obj0.add_i64(1) == 13 # type: ignore[attr-defined]
- assert type(obj0).add_i64.__doc__ == "add_i64 method" # type: ignore[attr-defined]
- assert type(obj0).v_i64.__doc__ == "i64 field" # type: ignore[attr-defined]
+ assert obj0.add_i64(1) == 13
+ assert type(obj0).add_i64.__doc__ == "add_i64 method"
+ assert type(obj0).v_i64.__doc__ == "i64 field"
def test_attribute() -> None:
- obj = tvm_ffi.testing.TestIntPair(3, 4) # type: ignore[call-arg]
+ obj = tvm_ffi.testing.TestIntPair(3, 4) # ty: ignore[too-many-positional-arguments]
assert obj.a == 3
assert obj.b == 4
assert type(obj).a.__doc__ == "Field `a`"
@@ -76,10 +76,10 @@ def test_setter() -> None:
assert obj0.v_str == "world"
with pytest.raises(TypeError):
- obj0.v_str = 1 # type: ignore[assignment]
+ obj0.v_str = 1 # ty: ignore[invalid-assignment]
with pytest.raises(TypeError):
- obj0.v_i64 = "hello" # type: ignore[assignment]
+ obj0.v_i64 = "hello" # ty: ignore[invalid-assignment]
def test_derived_object() -> None:
@@ -93,8 +93,8 @@ def test_derived_object() -> None:
"testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array
)
assert isinstance(obj0, tvm_ffi.testing.TestObjectDerived)
- assert obj0.v_map.same_as(v_map) # type: ignore[attr-defined]
- assert obj0.v_array.same_as(v_array) # type: ignore[attr-defined]
+ assert obj0.v_map.same_as(v_map) # ty: ignore[unresolved-attribute]
+ assert obj0.v_array.same_as(v_array) # ty: ignore[unresolved-attribute]
assert obj0.v_i64 == 20
assert obj0.v_f64 == 10.0
assert obj0.v_str == "hello"
@@ -129,7 +129,7 @@ def test_opaque_object() -> None:
def test_opaque_type_error() -> None:
obj0 = MyObject("hello")
with pytest.raises(TypeError) as e:
- tvm_ffi.testing.add_one(obj0) # type: ignore[arg-type]
+ tvm_ffi.testing.add_one(obj0) # ty: ignore[invalid-argument-type]
assert (
"Mismatched type on argument #0 when calling: `testing.add_one(0: int)
-> int`. Expected `int` but got `ffi.OpaquePyObject`"
in str(e.value)
@@ -154,18 +154,18 @@ def test_object_protocol() -> None:
def test_unregistered_object_fallback() -> None:
def _check_type(x: Any) -> None:
- type_info: TypeInfo = type(x).__tvm_ffi_type_info__ # type: ignore[attr-defined]
+ type_info: TypeInfo = type(x).__tvm_ffi_type_info__
assert type_info.type_key == "testing.TestUnregisteredObject"
assert x.v1 == 41
assert x.v2 == 42
- assert x.get_v1_plus_one() == 42 # type: ignore[attr-defined]
- assert x.get_v2_plus_two() == 44 # type: ignore[attr-defined]
+ assert x.get_v1_plus_one() == 42
+ assert x.get_v2_plus_two() == 44
assert type(x).__name__ == "TestUnregisteredObject"
assert type(x).__module__ == "testing"
assert type(x).__qualname__ == "testing.TestUnregisteredObject"
- assert "Auto-generated fallback class" in type(x).__doc__ # type:
ignore[operator]
- assert "Get (v1 + 1) from TestUnregisteredBaseObject" in
type(x).get_v1_plus_one.__doc__ # type: ignore[attr-defined]
- assert "Get (v2 + 2) from TestUnregisteredObject" in
type(x).get_v2_plus_two.__doc__ # type: ignore[attr-defined]
+ assert "Auto-generated fallback class" in type(x).__doc__
+ assert "Get (v1 + 1) from TestUnregisteredBaseObject" in
type(x).get_v1_plus_one.__doc__
+ assert "Get (v2 + 2) from TestUnregisteredObject" in
type(x).get_v2_plus_two.__doc__
obj = tvm_ffi.testing.make_unregistered_object()
_check_type(obj)
diff --git a/tests/python/test_optional_torch_c_dlpack.py b/tests/python/test_optional_torch_c_dlpack.py
index e8e4100..93d29cc 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -25,7 +25,7 @@ import pytest
try:
import torch
except ImportError:
- torch = None # type: ignore[assignment]
+ torch = None
import tvm_ffi
@@ -35,6 +35,7 @@ IS_WINDOWS = sys.platform.startswith("win")
@pytest.mark.skipif(torch is None, reason="torch is not installed")
def test_build_torch_c_dlpack_extension() -> None:
+ assert torch is not None
build_script = Path(tvm_ffi.__file__).parent / "utils" / "_build_optional_torch_c_dlpack.py"
args = [
sys.executable,
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index fa5c9b3..c1d9376 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -18,21 +18,19 @@
from __future__ import annotations
import ctypes
-from types import ModuleType
import pytest
import tvm_ffi
import tvm_ffi.cpp
-torch: ModuleType | None
try:
- import torch # type: ignore[no-redef]
+ import torch
except ImportError:
torch = None
try:
- from cuda.bindings import driver as cuda_driver # type: ignore[import-not-found]
+ from cuda.bindings import driver as cuda_driver
except ImportError:
cuda_driver = None
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 0c45654..d2f75ec 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -17,15 +17,13 @@
from __future__ import annotations
-from types import ModuleType
from typing import Any, NamedTuple, NoReturn
import numpy.typing as npt
import pytest
-torch: ModuleType | None
try:
- import torch # type: ignore[no-redef]
+ import torch
except ImportError:
torch = None
diff --git a/tests/python/utils/test_embed_cubin.py b/tests/python/utils/test_embed_cubin.py
index 33a8558..6a08cd7 100644
--- a/tests/python/utils/test_embed_cubin.py
+++ b/tests/python/utils/test_embed_cubin.py
@@ -22,14 +22,12 @@ import subprocess
import sys
import tempfile
from pathlib import Path
-from types import ModuleType
import pytest
from tvm_ffi.utils.embed_cubin import embed_cubin
-torch: ModuleType | None
try:
- import torch # type: ignore[import-not-found,no-redef]
+ import torch
except ImportError:
torch = None
diff --git a/tests/python/utils/test_kwargs_wrapper.py b/tests/python/utils/test_kwargs_wrapper.py
index aa304cd..4c370de 100644
--- a/tests/python/utils/test_kwargs_wrapper.py
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -124,7 +124,7 @@ def test_validation_errors() -> None:
# Invalid argument name types
with pytest.raises(TypeError, match="Argument name must be a string"):
- make_kwargs_wrapper(target, ["a", 123]) # type: ignore[list-item]
+ make_kwargs_wrapper(target, ["a", 123]) # ty:
ignore[invalid-argument-type]
# Invalid Python identifiers
with pytest.raises(ValueError, match="not a valid Python identifier"):
@@ -138,7 +138,7 @@ def test_validation_errors() -> None:
# arg_defaults not a tuple
with pytest.raises(TypeError, match="arg_defaults must be a tuple"):
- make_kwargs_wrapper(target, ["a", "b"], arg_defaults=[10]) # type: ignore[arg-type]
+ make_kwargs_wrapper(target, ["a", "b"], arg_defaults=[10]) # ty: ignore[invalid-argument-type]
# arg_defaults too long
with pytest.raises(ValueError, match=r"arg_defaults has .* values but only"):
@@ -190,7 +190,7 @@ def test_wrapper_with_signature() -> None:
# Test metadata preservation when prototype is provided
wrapper_with_metadata = make_kwargs_wrapper_from_signature(target, sig, source_func)
- assert wrapper_with_metadata.__name__ == "source_func"
+ assert wrapper_with_metadata.__name__ == "source_func" # ty: ignore[unresolved-attribute]
assert wrapper_with_metadata.__doc__ == "Source function documentation."
# With keyword-only arguments
@@ -322,7 +322,7 @@ def test_metadata_preservation() -> None:
target = lambda *args: sum(args)
wrapper = make_kwargs_wrapper(target, ["x", "y"], arg_defaults=(10,), prototype=my_function)
- assert wrapper.__name__ == "my_function"
+ assert wrapper.__name__ == "my_function" # ty: ignore[unresolved-attribute]
assert wrapper.__doc__ == "Document the function."
assert wrapper.__annotations__ == my_function.__annotations__
assert wrapper(5) == 15
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index 4798d00..471a048 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -290,9 +290,9 @@ def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
- x = tvm_ffi.core.DLTensorTestWrapper(x) # type: ignore[assignment]
- y = tvm_ffi.core.DLTensorTestWrapper(y) # type: ignore[assignment]
- z = tvm_ffi.core.DLTensorTestWrapper(z) # type: ignore[assignment]
+ x = tvm_ffi.core.DLTensorTestWrapper(x)
+ y = tvm_ffi.core.DLTensorTestWrapper(y)
+ z = tvm_ffi.core.DLTensorTestWrapper(z)
bench_tvm_ffi_nop_autodlpack(
f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z,
repeat
)
diff --git a/tests/scripts/benchmark_kwargs_wrapper.py b/tests/scripts/benchmark_kwargs_wrapper.py
index 892d06e..185400c 100644
--- a/tests/scripts/benchmark_kwargs_wrapper.py
+++ b/tests/scripts/benchmark_kwargs_wrapper.py
@@ -28,7 +28,7 @@ def print_speed(name: str, speed: float) -> None:
print(f"{name:<60} {speed} sec/call")
-def target_func(*args: Any) -> None: # type: ignore[no-untyped-def]
+def target_func(*args: Any) -> None:
pass