This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 40e9c83   feat: Enable mypy type checking (#35)
40e9c83 is described below

commit 40e9c83345636b187f39df86a077e0b56683c792
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 12:58:15 2025 -0700

     feat: Enable mypy type checking (#35)
    
    This PR enables mypy type checking for the following directories:
    - `python/` and `tests/`
    - `examples/inline_module`
    - `examples/packaging`
    - `examples/quick_start`
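
    The hooks can be reproduced outside pre-commit through mypy's
    programmatic API. A minimal sketch, assuming mypy is installed and the
    snippet runs from the repository root (the directory argument is
    illustrative):

        # Mirror one hook's invocation: same flags, one checked directory.
        from mypy import api

        stdout, stderr, exit_status = api.run(
            ["--show-error-codes", "--python-version=3.9", "python/"]
        )
        print(stdout or stderr)
        print("mypy exit status:", exit_status)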
---
 .pre-commit-config.yaml                            |  27 +++++
 .../packaging/python/my_ffi_extension/__init__.py  |   2 +-
 examples/quick_start/run_example.py                |  18 +---
 pyproject.toml                                     |   4 +
 python/tvm_ffi/_convert.py                         |  11 +-
 python/tvm_ffi/_dtype.py                           |   1 +
 python/tvm_ffi/_ffi_api.pyi                        |  43 ++++++++
 python/tvm_ffi/_optional_torch_c_dlpack.py         |  12 +--
 python/tvm_ffi/_tensor.py                          |  20 ++--
 python/tvm_ffi/access_path.py                      |  27 +++--
 python/tvm_ffi/container.py                        |  10 +-
 python/tvm_ffi/core.pyi                            |  17 +--
 python/tvm_ffi/cpp/load_inline.py                  | 115 +++++++++++----------
 python/tvm_ffi/cython/type_info.pxi                |  11 +-
 python/tvm_ffi/dataclasses/_utils.py               |  23 ++---
 python/tvm_ffi/dataclasses/c_class.py              |  13 ++-
 python/tvm_ffi/dataclasses/field.py                |  11 +-
 python/tvm_ffi/error.py                            |  12 ++-
 python/tvm_ffi/libinfo.py                          |  37 ++++---
 python/tvm_ffi/module.py                           |   4 +-
 python/tvm_ffi/registry.py                         |  62 +++++++----
 python/tvm_ffi/serialization.py                    |   6 +-
 python/tvm_ffi/stream.py                           |  10 +-
 python/tvm_ffi/testing.py                          |  19 +++-
 python/tvm_ffi/utils/lockfile.py                   |  16 +--
 tests/lint/check_asf_header.py                     |   2 +-
 tests/lint/check_file_type.py                      |   2 +-
 tests/python/test_container.py                     |   4 +-
 tests/python/test_dataclasses_c_class.py           |   2 +-
 tests/python/test_device.py                        |   4 +-
 tests/python/test_error.py                         |  20 ++--
 tests/python/test_function.py                      |  10 +-
 tests/python/test_load_inline.py                   |   8 +-
 tests/python/test_object.py                        |  14 ++-
 tests/python/test_stream.py                        |  10 +-
 tests/python/test_tensor.py                        |  13 ++-
 tests/scripts/benchmark_dlpack.py                  |   6 +-
 37 files changed, 395 insertions(+), 231 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8f762dc..329c00c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -83,3 +83,30 @@ repos:
     rev: v0.10.0.1
     hooks:
       - id: shellcheck
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: "v1.17.0"
+    hooks:
+      - id: mypy
+        name: mypy for `python/` and `tests/`
+        additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
+        args: [--show-error-codes, --python-version=3.9]
+        exclude: ^.*/_ffi_api\.py$
+        files: ^(python/|tests/).*\.py$
+      - id: mypy
+        name: mypy for `examples/inline_module`
+        additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
+        args: [--show-error-codes, --python-version=3.9]
+        exclude: ^.*/_ffi_api\.py$
+        files: ^examples/inline_module/.*\.py$
+      - id: mypy
+        name: mypy for `examples/packaging`
+        additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
+        args: [--show-error-codes, --python-version=3.9]
+        exclude: ^.*/_ffi_api\.py$
+        files: ^examples/packaging/.*\.py$
+      - id: mypy
+        name: mypy for `examples/quick_start`
+        additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
+        args: [--show-error-codes, --python-version=3.9]
+        exclude: ^.*/_ffi_api\.py$
+        files: ^examples/quick_start/.*\.py$
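
The four hooks share flags and dependencies but scope mypy to disjoint file
sets through `files:` and `exclude:`. A small sketch of that path filtering,
approximating pre-commit's regex matching (the sample paths are illustrative):

    import re

    FILES = re.compile(r"^(python/|tests/).*\.py$")
    EXCLUDE = re.compile(r"^.*/_ffi_api\.py$")  # populated only at runtime

    for path in [
        "python/tvm_ffi/registry.py",           # checked
        "python/tvm_ffi/_ffi_api.py",           # skipped via exclude
        "examples/quick_start/run_example.py",  # belongs to another hook
    ]:
        selected = bool(FILES.search(path)) and not EXCLUDE.search(path)
        print(f"{path}: {'checked' if selected else 'skipped'}")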
diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py
index ae4abfd..0c2b0fd 100644
--- a/examples/packaging/python/my_ffi_extension/__init__.py
+++ b/examples/packaging/python/my_ffi_extension/__init__.py
@@ -51,4 +51,4 @@ def raise_error(msg: str) -> None:
         The error raised by the function.
 
     """
-    return _ffi_api.raise_error(msg)
+    return _ffi_api.raise_error(msg)  # type: ignore[attr-defined]
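
The ignore above is scoped to one error code, which `--show-error-codes`
makes visible in mypy's output; everything else on the line stays checked.
A standalone sketch of the same situation, with hypothetical names:

    class _FFIStub:
        """Stands in for a module whose functions appear only at runtime."""

    _api = _FFIStub()
    setattr(_api, "raise_error", lambda msg: print("would raise:", msg))

    # mypy sees no `raise_error` on _FFIStub, hence the scoped ignore:
    _api.raise_error("boom")  # type: ignore[attr-defined]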
diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py
index 87c9507..2e2b7f3 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -16,15 +16,9 @@
 # under the License.
 """Quick start script to run tvm-ffi examples from prebuilt libraries."""
 
-import tvm_ffi
-
-try:
-    import torch
-except ImportError:
-    torch = None
-
-
 import numpy
+import torch
+import tvm_ffi
 
 
 def run_add_one_cpu() -> None:
@@ -40,9 +34,6 @@ def run_add_one_cpu() -> None:
     print("numpy.result after add_one(x, y)")
     print(x)
 
-    if torch is None:
-        return
-
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
     y = torch.empty_like(x)
     # tvm-ffi automatically handles DLPack compatible tensors
@@ -63,9 +54,6 @@ def run_add_one_c() -> None:
     print("numpy.result after add_one_c(x, y)")
     print(x)
 
-    if torch is None:
-        return
-
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
     y = torch.empty_like(x)
     mod.add_one_c(x, y)
@@ -75,7 +63,7 @@ def run_add_one_c() -> None:
 
 def run_add_one_cuda() -> None:
     """Load the add_one_cuda module and call the add_one_cuda function."""
-    if torch is None or not torch.cuda.is_available():
+    if not torch.cuda.is_available():
         return
 
     mod = tvm_ffi.load_module("build/add_one_cuda.so")
diff --git a/pyproject.toml b/pyproject.toml
index d2fd897..58d394a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -213,3 +213,7 @@ environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" }
 
 [tool.cibuildwheel.windows]
 archs = ["AMD64"]
+
+[tool.mypy]
+allow_redefinition = true
+ignore_missing_imports = true
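
`allow_redefinition` permits rebinding a name to a different type within the
same block, which mypy otherwise rejects; `ignore_missing_imports` silences
errors for third-party packages that ship no type information. A minimal
sketch of the rebinding the first option allows (the function is illustrative):

    def parse_port(raw: str) -> int:
        port = raw.strip()  # port: str
        port = int(port)    # rebound to int: fine with allow_redefinition
        return port

    print(parse_port(" 8080 "))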
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 6313c8e..05e99da 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -16,20 +16,25 @@
 # under the License.
 """Conversion utilities to bring python objects into ffi values."""
 
+from __future__ import annotations
+
 from numbers import Number
+from types import ModuleType
 from typing import Any
 
 from . import container, core
 
+torch: ModuleType | None = None
 try:
-    import torch
+    import torch  # type: ignore[no-redef]
 except ImportError:
-    torch = None
+    pass
 
+numpy: ModuleType | None = None
 try:
     import numpy
 except ImportError:
-    numpy = None
+    pass
 
 
 def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
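
The hunk above is the optional-import idiom this change applies in several
modules: pre-declare the name as `ModuleType | None` so the import and the
fallback agree on one type, then suppress the benign `no-redef` on the import
itself. Reduced to a standalone sketch:

    from __future__ import annotations

    from types import ModuleType

    torch: ModuleType | None = None
    try:
        import torch  # type: ignore[no-redef]
    except ImportError:
        pass

    if torch is not None:
        print("torch available:", torch.__name__)
    else:
        print("torch not installed; taking the fallback path")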
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index ba1735f..8306593 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -59,6 +59,7 @@ class dtype(str):
     """
 
     __slots__ = ["__tvm_ffi_dtype__"]
+    __tvm_ffi_dtype__: core.DataType
 
     _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
 
diff --git a/python/tvm_ffi/_ffi_api.pyi b/python/tvm_ffi/_ffi_api.pyi
new file mode 100644
index 0000000..95059e5
--- /dev/null
+++ b/python/tvm_ffi/_ffi_api.pyi
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API."""
+
+from typing import Any
+
+def ModuleGetKind(*args: Any) -> Any: ...
+def ModuleImplementsFunction(*args: Any) -> Any: ...
+def ModuleGetFunction(*args: Any) -> Any: ...
+def ModuleImportModule(*args: Any) -> Any: ...
+def ModuleInspectSource(*args: Any) -> Any: ...
+def ModuleGetWriteFormats(*args: Any) -> Any: ...
+def ModuleGetPropertyMask(*args: Any) -> Any: ...
+def ModuleClearImports(*args: Any) -> Any: ...
+def ModuleWriteToFile(*args: Any) -> Any: ...
+def ModuleLoadFromFile(*args: Any) -> Any: ...
+def SystemLib(*args: Any) -> Any: ...
+def Array(*args: Any) -> Any: ...
+def ArrayGetItem(*args: Any) -> Any: ...
+def ArraySize(*args: Any) -> Any: ...
+def MapForwardIterFunctor(*args: Any) -> Any: ...
+def Map(*args: Any) -> Any: ...
+def MapGetItem(*args: Any) -> Any: ...
+def MapCount(*args: Any) -> Any: ...
+def MapSize(*args: Any) -> Any: ...
+def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
+def ToJSONGraphString(*args: Any) -> Any: ...
+def FromJSONGraphString(*args: Any) -> Any: ...
+def Shape(*args: Any) -> Any: ...
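
Because `_ffi_api.py` is populated dynamically at import time, mypy cannot
see its functions; the stub above is what mypy reads instead, since `.pyi`
files always shadow a same-named `.py` module. A self-contained sketch of
that mechanism, using throwaway files and mypy's programmatic API:

    import pathlib
    import tempfile

    from mypy import api

    with tempfile.TemporaryDirectory() as root:
        pkg = pathlib.Path(root)
        (pkg / "_api.py").write_text("")  # empty: filled at runtime in real code
        (pkg / "_api.pyi").write_text(
            "from typing import Any\ndef Shape(*args: Any) -> Any: ...\n"
        )
        (pkg / "main.py").write_text("import _api\n_api.Shape(1, 2, 3)\n")
        stdout, _, status = api.run([str(pkg / "main.py")])
        print(stdout, "exit status:", status)  # clean: the stub supplies the type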
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index dd820a9..5be7211 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -31,12 +31,12 @@ subsequent calls will be much faster.
 """
 
 import warnings
-from typing import Any, Optional
+from typing import Any
 
 from . import libinfo
 
 
-def load_torch_c_dlpack_extension() -> Optional[Any]:
+def load_torch_c_dlpack_extension() -> Any:
     """Load the torch c dlpack extension."""
     cpp_source = """
 #include <dlpack/dlpack.h>
@@ -556,9 +556,9 @@ int64_t TorchDLPackTensorAllocatorPtr() {
             extra_include_paths=include_paths,
         )
         # set the dlpack related flags
-        torch.Tensor.__c_dlpack_from_pyobject__ = mod.TorchDLPackFromPyObjectPtr()
-        torch.Tensor.__c_dlpack_to_pyobject__ = mod.TorchDLPackToPyObjectPtr()
-        torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr()
+        setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
+        setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
+        setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
         return mod
     except ImportError:
         pass
@@ -566,7 +566,7 @@ int64_t TorchDLPackTensorAllocatorPtr() {
         warnings.warn(
             f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator 
will not be enabled."
         )
-        return None
+    return None
 
 
 # keep alive
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 0cc09f1..0d44994 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -15,11 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tensor related objects and functions."""
+
+from __future__ import annotations
+
 # we name it as _tensor.py to avoid potential future case
 # if we also want to expose a tensor function in the root namespace
-
 from numbers import Integral
-from typing import Any, Optional, Union
+from typing import Any
 
 from . import _ffi_api, core, registry
 from .core import (
@@ -43,23 +45,25 @@ class Shape(tuple, PyNativeObject):
 
     """
 
-    def __new__(cls, content: tuple[int, ...]) -> "Shape":
+    __tvm_ffi_object__: Any
+
+    def __new__(cls, content: tuple[int, ...]) -> Shape:
         if any(not isinstance(x, Integral) for x in content):
             raise ValueError("Shape must be a tuple of integers")
-        val = tuple.__new__(cls, content)
+        val: Shape = tuple.__new__(cls, content)
         val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content)
         return val
 
     # pylint: disable=no-self-argument
-    def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
+    def __from_tvm_ffi_object__(cls, obj: Any) -> Shape:
         """Construct from a given tvm object."""
         content = _shape_obj_get_py_tuple(obj)
-        val = tuple.__new__(cls, content)
-        val.__tvm_ffi_object__ = obj
+        val: Shape = tuple.__new__(cls, content)  # type: ignore[arg-type]
+        val.__tvm_ffi_object__ = obj  # type: ignore[attr-defined]
         return val
 
 
-def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = None) -> Device:
+def device(device_type: str | int | DLDeviceType, index: int | None = None) -> Device:
     """Construct a TVM FFI device with given device type and index.
 
     Parameters
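
The `from __future__ import annotations` added above is what makes the PEP 604
unions and unquoted forward references legal on Python 3.9: annotations remain
unevaluated strings instead of being executed at import time. A minimal sketch:

    from __future__ import annotations


    class Node:
        # Both `Node | None` and the bare `Node` return type would fail at
        # import time on 3.9 without the future import.
        def link(self, other: Node | None) -> Node:
            self.next = other
            return self


    print(type(Node().link(None)).__name__)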
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index e8aec10..aa52d58 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -39,11 +39,16 @@ class AccessKind(IntEnum):
 class AccessStep(core.Object):
     """Access step container."""
 
+    kind: AccessKind
+    key: Any
+
 
 @register_object("ffi.reflection.AccessPath")
 class AccessPath(core.Object):
     """Access path container."""
 
+    parent: "AccessPath"
+
     def __init__(self) -> None:
         """Disallow direct construction; use `AccessPath.root()` instead."""
         super().__init__()
@@ -55,19 +60,19 @@ class AccessPath(core.Object):
     @staticmethod
     def root() -> "AccessPath":
         """Create a root access path."""
-        return AccessPath._root()
+        return AccessPath._root()  # type: ignore[attr-defined]
 
     def __eq__(self, other: Any) -> bool:
         """Return whether two access paths are equal."""
         if not isinstance(other, AccessPath):
             return False
-        return self._path_equal(other)
+        return self._path_equal(other)  # type: ignore[attr-defined]
 
     def __ne__(self, other: Any) -> bool:
         """Return whether two access paths are not equal."""
         if not isinstance(other, AccessPath):
             return True
-        return not self._path_equal(other)
+        return not self._path_equal(other)  # type: ignore[attr-defined]
 
     def is_prefix_of(self, other: "AccessPath") -> bool:
         """Check if this access path is a prefix of another access path.
@@ -83,7 +88,7 @@ class AccessPath(core.Object):
            True if this access path is a prefix of the other access path, False otherwise
 
         """
-        return self._is_prefix_of(other)
+        return self._is_prefix_of(other)  # type: ignore[attr-defined]
 
     def attr(self, attr_key: str) -> "AccessPath":
         """Create an access path to the attribute of the current object.
@@ -99,7 +104,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._attr(attr_key)
+        return self._attr(attr_key)  # type: ignore[attr-defined]
 
     def attr_missing(self, attr_key: str) -> "AccessPath":
         """Create an access path that indicate an attribute is missing.
@@ -115,7 +120,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._attr_missing(attr_key)
+        return self._attr_missing(attr_key)  # type: ignore[attr-defined]
 
     def array_item(self, index: int) -> "AccessPath":
         """Create an access path to the item of the current array.
@@ -131,7 +136,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._array_item(index)
+        return self._array_item(index)  # type: ignore[attr-defined]
 
     def array_item_missing(self, index: int) -> "AccessPath":
         """Create an access path that indicate an array item is missing.
@@ -147,7 +152,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._array_item_missing(index)
+        return self._array_item_missing(index)  # type: ignore[attr-defined]
 
     def map_item(self, key: Any) -> "AccessPath":
         """Create an access path to the item of the current map.
@@ -163,7 +168,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._map_item(key)
+        return self._map_item(key)  # type: ignore[attr-defined]
 
     def map_item_missing(self, key: Any) -> "AccessPath":
         """Create an access path that indicate a map item is missing.
@@ -179,7 +184,7 @@ class AccessPath(core.Object):
             The extended access path
 
         """
-        return self._map_item_missing(key)
+        return self._map_item_missing(key)  # type: ignore[attr-defined]
 
     def to_steps(self) -> list["AccessStep"]:
         """Convert the access path to a list of access steps.
@@ -190,6 +195,6 @@ class AccessPath(core.Object):
             The list of access steps
 
         """
-        return self._to_steps()
+        return self._to_steps()  # type: ignore[attr-defined]
 
     __hash__ = core.Object.__hash__
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 008dda9..6f29dfd 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -16,6 +16,8 @@
 # under the License.
 """Container classes."""
 
+from __future__ import annotations
+
 import collections.abc
 from collections.abc import Iterator, Mapping, Sequence
 from typing import Any, Callable
@@ -121,7 +123,7 @@ class Array(core.Object, collections.abc.Sequence):
 class KeysView(collections.abc.KeysView):
     """Helper class to return keys view."""
 
-    def __init__(self, backend_map: "Map") -> None:
+    def __init__(self, backend_map: Map) -> None:
         self._backend_map = backend_map
 
     def __len__(self) -> int:
@@ -144,7 +146,7 @@ class KeysView(collections.abc.KeysView):
 class ValuesView(collections.abc.ValuesView):
     """Helper class to return values view."""
 
-    def __init__(self, backend_map: "Map") -> None:
+    def __init__(self, backend_map: Map) -> None:
         self._backend_map = backend_map
 
     def __len__(self) -> int:
@@ -164,7 +166,7 @@ class ValuesView(collections.abc.ValuesView):
 class ItemsView(collections.abc.ItemsView):
     """Helper class to return items view."""
 
-    def __init__(self, backend_map: "Map") -> None:
+    def __init__(self, backend_map: Map) -> None:
         self.backend_map = backend_map
 
     def __len__(self) -> int:
@@ -231,7 +233,7 @@ class Map(core.Object, collections.abc.Mapping):
         """Return a dynamic view of the map's keys."""
         return KeysView(self)
 
-    def values(self) -> "ValuesView":
+    def values(self) -> ValuesView:
         """Return a dynamic view of the map's values."""
         return ValuesView(self)
 
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index cfb3ea9..7c623af 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import types
+from ctypes import c_void_p
 from enum import IntEnum
 from typing import Any, Callable
 
@@ -28,7 +29,6 @@ ERROR_TYPE_TO_NAME: dict[type, str]
 
 _WITH_APPEND_BACKTRACE: Callable[[BaseException, str], BaseException] | None
 _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType | None], str] | None
-
 # DLPack protocol version (defined in tensor.pxi)
 __dlpack_version__: tuple[int, int]
 
@@ -44,7 +44,7 @@ class Object:
     def __eq__(self, other: Any) -> bool: ...
     def __ne__(self, other: Any) -> bool: ...
     def __hash__(self) -> int: ...
-    def __init_handle_by_constructor__(self, fconstructor: Function, *args: Any) -> None: ...
+    def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ...
     def __ffi_init__(self, *args: Any) -> None:
         """Initialize the instance using the ` __init__` method registered on 
C++ side.
 
@@ -78,9 +78,7 @@ class PyNativeObject:
     """Base class of all TVM objects that also subclass python's builtin 
types."""
 
     __slots__: list[str]
-    def __init_tvm_ffi_object_by_constructor__(
-        self, fconstructor: Function, *args: Any
-    ) -> None: ...
+    def __init_tvm_ffi_object_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ...
 
 def _set_class_object(cls: type) -> None: ...
 def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
@@ -101,7 +99,9 @@ class Error(Object):
     def backtrace(self) -> str: ...
 
 def _convert_to_ffi_error(error: BaseException) -> Error: ...
-def _env_set_current_stream(device_type: int, device_id: int, stream: int) -> int: ...
+def _env_set_current_stream(
+    device_type: int, device_id: int, stream: int | c_void_p
+) -> int | c_void_p: ...
 
 class DataType:
     """DataType wrapper around DLDataType."""
@@ -121,6 +121,8 @@ class DataType:
     def __str__(self) -> str: ...
 
 def _set_class_dtype(cls: type) -> None: ...
+def _convert_torch_dtype_to_ffi_dtype(torch_dtype: Any) -> DataType: ...
+def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) -> DataType: ...
def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ...
 
 class DLDeviceType(IntEnum):
@@ -185,6 +187,9 @@ class Tensor(Object):
         copy: bool | None = None,
     ) -> Any: ...
 
+_CLASS_TENSOR: type[Tensor] = Tensor
+
+def _set_class_tensor(cls: type[Tensor]) -> None: ...
 def from_dlpack(
    ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool = ...
 ) -> Tensor: ...
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index 2c1caf5..d7a5c14 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -16,6 +16,8 @@
 # under the License.
 """Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja."""
 
+from __future__ import annotations
+
 import functools
 import hashlib
 import os
@@ -24,7 +26,6 @@ import subprocess
 import sys
 from collections.abc import Mapping, Sequence
 from pathlib import Path
-from typing import Optional
 
 from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi
 from tvm_ffi.module import Module, load_module
@@ -77,7 +78,7 @@ def _maybe_write(path: str, content: str) -> None:
 
 
 @functools.lru_cache
-def _find_cuda_home() -> Optional[str]:
+def _find_cuda_home() -> str:
     """Find the CUDA install path."""
     # Guess #1
     cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
@@ -92,9 +93,10 @@ def _find_cuda_home() -> Optional[str]:
                 cuda_root = Path("C:/Program Files/NVIDIA GPU Computing 
Toolkit/CUDA")
                 cuda_homes = list(cuda_root.glob("v*.*"))
                 if len(cuda_homes) == 0:
-                    cuda_home = ""
-                else:
-                    cuda_home = str(cuda_homes[0])
+                    raise RuntimeError(
+                        "Could not find CUDA installation. Please set CUDA_HOME environment variable."
+                    )
+                cuda_home = str(cuda_homes[0])
             else:
                 cuda_home = "/usr/local/cuda"
             if not Path(cuda_home).exists():
@@ -358,17 +360,17 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
     return "\n".join(sources)
 
 
-def load_inline(
+def load_inline(  # noqa: PLR0912, PLR0915
     name: str,
     *,
-    cpp_sources: str | None = None,
-    cuda_sources: str | None = None,
-    functions: Sequence[str] | None = None,
+    cpp_sources: Sequence[str] | str | None = None,
+    cuda_sources: Sequence[str] | str | None = None,
+    functions: Mapping[str, str] | Sequence[str] | str | None = None,
     extra_cflags: Sequence[str] | None = None,
     extra_cuda_cflags: Sequence[str] | None = None,
     extra_ldflags: Sequence[str] | None = None,
     extra_include_paths: Sequence[str] | None = None,
-    build_directory: Optional[str] = None,
+    build_directory: str | None = None,
 ) -> Module:
     """Compile and load a C++/CUDA module from inline source code.
 
@@ -481,76 +483,83 @@ def load_inline(
 
     """
     if cpp_sources is None:
-        cpp_sources = []
+        cpp_source_list: list[str] = []
     elif isinstance(cpp_sources, str):
-        cpp_sources = [cpp_sources]
-    cpp_source = "\n".join(cpp_sources)
+        cpp_source_list = [cpp_sources]
+    else:
+        cpp_source_list = list(cpp_sources)
+    cpp_source = "\n".join(cpp_source_list)
+    with_cpp = bool(cpp_source_list)
+    del cpp_source_list
+
     if cuda_sources is None:
-        cuda_sources = []
+        cuda_source_list: list[str] = []
     elif isinstance(cuda_sources, str):
-        cuda_sources = [cuda_sources]
-    cuda_source = "\n".join(cuda_sources)
-    with_cpp = len(cpp_sources) > 0
-    with_cuda = len(cuda_sources) > 0
+        cuda_source_list = [cuda_sources]
+    else:
+        cuda_source_list = list(cuda_sources)
+    cuda_source = "\n".join(cuda_source_list)
+    with_cuda = bool(cuda_source_list)
+    del cuda_source_list
 
-    extra_ldflags = extra_ldflags or []
-    extra_cflags = extra_cflags or []
-    extra_cuda_cflags = extra_cuda_cflags or []
-    extra_include_paths = extra_include_paths or []
+    extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else []
+    extra_cflags_list = list(extra_cflags) if extra_cflags is not None else []
+    extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is not None else []
+    extra_include_paths_list = list(extra_include_paths) if extra_include_paths is not None else []
 
     # add function registration code to sources
-    if isinstance(functions, str):
-        functions = {functions: ""}
-    elif isinstance(functions, Sequence):
-        functions = {name: "" for name in functions}
+    if functions is None:
+        function_map: dict[str, str] = {}
+    elif isinstance(functions, str):
+        function_map = {functions: ""}
+    elif isinstance(functions, Mapping):
+        function_map = dict(functions)
+    else:
+        function_map = {name: "" for name in functions}
 
     if with_cpp:
-        cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
+        cpp_source = _decorate_with_tvm_ffi(cpp_source, function_map)
         cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
     else:
         cpp_source = _decorate_with_tvm_ffi(cpp_source, {})
-        cuda_source = _decorate_with_tvm_ffi(cuda_source, functions)
+        cuda_source = _decorate_with_tvm_ffi(cuda_source, function_map)
 
     # determine the cache dir for the built module
+    build_dir: Path
     if build_directory is None:
-        build_directory = os.environ.get(
-            "TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser())
-        )
+        cache_dir = os.environ.get("TVM_FFI_CACHE_DIR", 
str(Path("~/.cache/tvm-ffi").expanduser()))
         source_hash: str = _hash_sources(
             cpp_source,
             cuda_source,
-            functions,
-            extra_cflags,
-            extra_cuda_cflags,
-            extra_ldflags,
-            extra_include_paths,
+            function_map,
+            extra_cflags_list,
+            extra_cuda_cflags_list,
+            extra_ldflags_list,
+            extra_include_paths_list,
         )
-        build_dir: str = str(Path(build_directory) / f"{name}_{source_hash}")
+        build_dir = Path(cache_dir).expanduser() / f"{name}_{source_hash}"
     else:
-        build_dir = str(Path(build_directory).resolve())
-    Path(build_dir).mkdir(parents=True, exist_ok=True)
+        build_dir = Path(build_directory).resolve()
+    build_dir.mkdir(parents=True, exist_ok=True)
 
     # generate build.ninja
     ninja_source = _generate_ninja_build(
         name=name,
-        build_dir=build_dir,
+        build_dir=str(build_dir),
         with_cuda=with_cuda,
-        extra_cflags=extra_cflags,
-        extra_cuda_cflags=extra_cuda_cflags,
-        extra_ldflags=extra_ldflags,
-        extra_include_paths=extra_include_paths,
+        extra_cflags=extra_cflags_list,
+        extra_cuda_cflags=extra_cuda_cflags_list,
+        extra_ldflags=extra_ldflags_list,
+        extra_include_paths=extra_include_paths_list,
     )
-
-    with FileLock(str(Path(build_dir) / "lock")):
+    with FileLock(str(build_dir / "lock")):
         # write source files and build.ninja if they do not already exist
-        _maybe_write(str(Path(build_dir) / "main.cpp"), cpp_source)
+        _maybe_write(str(build_dir / "main.cpp"), cpp_source)
         if with_cuda:
-            _maybe_write(str(Path(build_dir) / "cuda.cu"), cuda_source)
-        _maybe_write(str(Path(build_dir) / "build.ninja"), ninja_source)
-
+            _maybe_write(str(build_dir / "cuda.cu"), cuda_source)
+        _maybe_write(str(build_dir / "build.ninja"), ninja_source)
         # build the module
-        _build_ninja(build_dir)
-
+        _build_ninja(str(build_dir))
         # Use appropriate extension based on platform
         ext = ".dll" if IS_WINDOWS else ".so"
-        return load_module(str((Path(build_dir) / f"{name}{ext}").resolve()))
+        return load_module(str((build_dir / f"{name}{ext}").resolve()))
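
The rewrite above repeats one normalization idiom: accept loose parameter
types such as `str | Sequence[str] | None`, then immediately funnel them into
a freshly named, precisely typed local, so mypy narrows once instead of
tracking a name whose type mutates. The idiom in isolation (names are
illustrative):

    from __future__ import annotations

    from collections.abc import Sequence


    def join_sources(sources: str | Sequence[str] | None) -> str:
        if sources is None:
            source_list: list[str] = []  # one precise type from here on
        elif isinstance(sources, str):
            source_list = [sources]
        else:
            source_list = list(sources)
        return "\n".join(source_list)


    print(join_sources(["a", "b"]), join_sources("c"), join_sources(None))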
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index 4ab9f15..fcd443b 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import dataclasses
+from typing import Optional, Any
 
 
 cdef class FieldGetter:
@@ -62,13 +63,13 @@ class TypeField:
     """Description of a single reflected field on an FFI-backed type."""
 
     name: str
-    doc: str | None
+    doc: Optional[str]
     size: int
     offset: int
     frozen: bool
     getter: FieldGetter
     setter: FieldSetter
-    dataclass_field: object | None = None
+    dataclass_field: Any = None
 
     def __post_init__(self):
         assert self.setter is not None
@@ -96,7 +97,7 @@ class TypeMethod:
     """Description of a single reflected method on an FFI-backed type."""
 
     name: str
-    doc: str | None
+    doc: Optional[str]
     func: object
     is_static: bool
 
@@ -105,9 +106,9 @@ class TypeMethod:
 class TypeInfo:
     """Aggregated type information required to build a proxy class."""
 
-    type_cls: type | None
+    type_cls: Optional[type]
     type_index: int
     type_key: str
     fields: list[TypeField]
     methods: list[TypeMethod]
-    parent_type_info: TypeInfo | None
+    parent_type_info: Optional[TypeInfo]
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index ef7c7e4..7a28e01 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 import functools
 import inspect
 from dataclasses import MISSING
-from typing import Any, Callable, NamedTuple, TypeVar
+from typing import Any, Callable, NamedTuple, TypeVar, cast
 
 from ..core import (
     Object,
@@ -68,7 +68,7 @@ def type_info_to_cls(
         attrs[field.name] = field.as_property(cls)
 
     # Step 3. Add methods
-    def _add_method(name: str, func: Callable) -> None:
+    def _add_method(name: str, func: Callable[..., Any]) -> None:
         if name == "__ffi_init__":
             name = "__c_ffi_init__"
         if name in attrs:  # already defined
@@ -80,9 +80,9 @@ def type_info_to_cls(
         attrs[name] = func
         setattr(cls, name, func)
 
-    for name, method in methods.items():
-        if method is not None:
-            _add_method(name, method)
+    for name, method_impl in methods.items():
+        if method_impl is not None:
+            _add_method(name, method_impl)
     for method in type_info.methods:
         _add_method(method.name, method.func)
 
@@ -90,7 +90,7 @@ def type_info_to_cls(
     new_cls = type(cls.__name__, cls_bases, attrs)
     new_cls.__module__ = cls.__module__
     new_cls = functools.wraps(cls, updated=())(new_cls)  # type: ignore
-    return new_cls
+    return cast(type[_InputClsType], new_cls)
 
 
 def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
@@ -123,15 +123,12 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:  #
 
         fn: Callable[[], Any]
 
-    fields: list[TypeInfo] = []
-    cur_type_info = type_info
-    while True:
+    fields: list[TypeField] = []
+    cur_type_info: TypeInfo | None = type_info
+    while cur_type_info is not None:
         fields.extend(reversed(cur_type_info.fields))
         cur_type_info = cur_type_info.parent_type_info
-        if cur_type_info is None:
-            break
     fields.reverse()
-    del cur_type_info
 
     annotations: dict[str, Any] = {"return": None}
     # Step 1. Split the parameters into two groups to ensure that
@@ -187,7 +184,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:  #
     else:
         raise ValueError(f"Cannot find constructor method: 
`{type_info.type_key}.__ffi_init__`")
 
-    def __init__(self: type, *args: Any, **kwargs: Any) -> None:
+    def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
         e = None
         try:
             args = bind_args(*args, **kwargs)
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index 7507b76..700f287 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -28,18 +28,21 @@ from dataclasses import InitVar
 from typing import ClassVar, TypeVar, get_origin, get_type_hints
 
 from ..core import TypeField, TypeInfo
-from . import _utils, field
+from . import _utils
+from .field import Field, field
 
 try:
-    from typing import dataclass_transform
+    from typing_extensions import dataclass_transform  # type: ignore[attr-defined]
 except ImportError:
-    from typing_extensions import dataclass_transform
+    from typing import dataclass_transform  # type: ignore[no-redef,attr-defined]
+except ImportError:
+    pass
 
 
 _InputClsType = TypeVar("_InputClsType")
 
 
-@dataclass_transform(field_specifiers=(field.field, field.Field))
+@dataclass_transform(field_specifiers=(field, Field))
 def c_class(
     type_key: str, init: bool = True
 ) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
@@ -157,7 +160,7 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
     for field_name, _field_ty_py in type_hints_py.items():
         if field_name.startswith("__tvm_ffi"):  # TVM's private fields - skip
             continue
-        type_field: TypeField = type_fields_cxx.pop(field_name, None)
+        type_field = type_fields_cxx.pop(field_name, None)
         if type_field is None:
             raise ValueError(
                 f"Extraneous field `{type_cls}.{field_name}`. Defined in 
Python but not in C++"
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index 00170e5..f1a582e 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -18,11 +18,10 @@
 
 from __future__ import annotations
 
-from dataclasses import MISSING, dataclass
+from dataclasses import MISSING
 from typing import Any, Callable
 
 
-@dataclass(kw_only=True)
 class Field:
     """(Experimental) Descriptor placeholder returned by 
:func:`tvm_ffi.dataclasses.field`.
 
@@ -36,8 +35,12 @@ class Field:
     way the decorator understands.
     """
 
-    name: str | None = None
-    default_factory: Callable[[], Any]
+    __slots__ = ("default_factory", "name")
+
+    def __init__(self, *, name: str | None = None, default_factory: Callable[[], Any]) -> None:
+        """Do not call directly; use :func:`field` instead."""
+        self.name = name
+        self.default_factory = default_factory
 
 
 def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field:
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index f29fc90..fd0bf2b 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -17,11 +17,13 @@
 # pylint: disable=invalid-name
 """Error handling."""
 
+from __future__ import annotations
+
 import ast
 import re
 import sys
 import types
-from typing import Any, Optional
+from typing import Any
 
 from . import core
 
@@ -60,7 +62,7 @@ class TracebackManager:
 
     def __init__(self) -> None:
         """Initialize the traceback manager and its cache."""
-        self._code_cache = {}
+        self._code_cache: dict[tuple[str, int, str], types.CodeType] = {}
 
    def _get_cached_code_object(self, filename: str, lineno: int, func: str) -> types.CodeType:
         # Hack to create a code object that points to the correct
@@ -95,7 +97,7 @@ class TracebackManager:
 
     def append_traceback(
         self,
-        tb: Optional[types.TracebackType],
+        tb: types.TracebackType | None,
         filename: str,
         lineno: int,
         func: str,
@@ -134,7 +136,7 @@ def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseExcep
     return py_error.with_traceback(tb)
 
 
-def _traceback_to_backtrace_str(tb: Optional[types.TracebackType]) -> str:
+def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str:
     """Convert the traceback to a string."""
     lines = []
     while tb is not None:
@@ -155,7 +157,7 @@ core._TRACEBACK_TO_BACKTRACE_STR = _traceback_to_backtrace_str
 
 def register_error(
     name_or_cls: str | type | None = None,
-    cls: Optional[type] = None,
+    cls: type | None = None,
 ) -> Any:
     """Register an error class so it can be recognized by the ffi error 
handler.
 
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index 382690b..8d92df3 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -46,25 +46,24 @@ def split_env_var(env_var: str, split: str) -> list[str]:
 def get_dll_directories() -> list[str]:
     """Get the possible dll directories."""
     ffi_dir = Path(__file__).expanduser().resolve().parent
-    dll_path = [ffi_dir / "lib"]
-    dll_path += [ffi_dir / ".." / ".." / "build" / "lib"]
+    dll_path: list[Path] = [ffi_dir / "lib"]
+    dll_path.append(ffi_dir / ".." / ".." / "build" / "lib")
     # in source build from parent if needed
-    dll_path += [ffi_dir / ".." / ".." / ".." / "build" / "lib"]
-
+    dll_path.append(ffi_dir / ".." / ".." / ".." / "build" / "lib")
     if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
-        dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":"))
-        dll_path.extend(split_env_var("PATH", ":"))
+        dll_path.extend(Path(p) for p in split_env_var("LD_LIBRARY_PATH", ":"))
+        dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
     elif sys.platform.startswith("darwin"):
-        dll_path.extend(split_env_var("DYLD_LIBRARY_PATH", ":"))
-        dll_path.extend(split_env_var("PATH", ":"))
+        dll_path.extend(Path(p) for p in split_env_var("DYLD_LIBRARY_PATH", 
":"))
+        dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
     elif sys.platform.startswith("win32"):
-        dll_path.extend(split_env_var("PATH", ";"))
-    return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()]
+        dll_path.extend(Path(p) for p in split_env_var("PATH", ";"))
+    return [str(path.resolve()) for path in dll_path if path.is_dir()]
 
 
 def find_libtvm_ffi() -> str:
     """Find libtvm_ffi."""
-    dll_path = get_dll_directories()
+    dll_path = [Path(p) for p in get_dll_directories()]
     if sys.platform.startswith("win32"):
         lib_dll_names = ["tvm_ffi.dll"]
     elif sys.platform.startswith("darwin"):
@@ -72,14 +71,18 @@ def find_libtvm_ffi() -> str:
     else:
         lib_dll_names = ["libtvm_ffi.so"]
 
-    name = lib_dll_names
-    lib_dll_path = [str(Path(p) / name) for name in lib_dll_names for p in dll_path]
-    lib_found = [p for p in lib_dll_path if Path(p).exists() and Path(p).is_file()]
+    lib_dll_path = [p / name for name in lib_dll_names for p in dll_path]
+    lib_found = [p for p in lib_dll_path if p.exists() and p.is_file()]
 
     if not lib_found:
-        raise RuntimeError(f"Cannot find library: {name}\nList of 
candidates:\n{lib_dll_path}")
-
-    return lib_found[0]
+        candidate_list = "\n".join(str(p) for p in lib_dll_path)
+        raise RuntimeError(
+            "Cannot find library: {}\nList of candidates:\n{}".format(
+                ", ".join(lib_dll_names), candidate_list
+            )
+        )
+
+    return str(lib_found[0])
 
 
 def find_source_path() -> str:
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index acdc11e..768463d 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -73,7 +73,7 @@ class Module(core.Object):
             The module
 
         """
-        return self.imports_
+        return self.imports_  # type: ignore[return-value]
 
    def implements_function(self, name: str, query_imports: bool = False) -> bool:
         """Return True if the module defines a global function.
@@ -255,7 +255,7 @@ def system_lib(symbol_prefix: str = "") -> Module:
 
     Parameters
     ----------
-    symbol_prefix: Optional[str]
+    symbol_prefix: str = ""
        Optional symbol prefix that can be used for search. When we lookup a symbol
        symbol_prefix + name will first be searched, then the name without symbol_prefix.
 
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 6bf08f6..3ef4039 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -16,8 +16,10 @@
 # under the License.
 """FFI registry to register function and objects."""
 
+from __future__ import annotations
+
 import sys
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Literal, overload
 
 from . import core
 from .core import TypeInfo
@@ -26,7 +28,7 @@ from .core import TypeInfo
 _SKIP_UNKNOWN_OBJECTS = False
 
 
-def register_object(type_key: str | type | None = None) -> Any:
+def register_object(type_key: str | type | None = None) -> Callable[[type], type] | type:
     """Register object type.
 
     Parameters
@@ -46,9 +48,8 @@ def register_object(type_key: str | type | None = None) -> Any:
           pass
 
     """
-    object_name = type_key if isinstance(type_key, str) else type_key.__name__
 
-    def register(cls: type) -> type:
+    def _register(cls: type, object_name: str) -> type:
         """Register the object type with the FFI core."""
         type_index = core._object_type_key_to_index(object_name)
         if type_index is None:
@@ -60,14 +61,25 @@ def register_object(type_key: str | type | None = None) -> Any:
         return cls
 
     if isinstance(type_key, str):
-        return register
 
-    return register(type_key)
+        def _decorator_with_name(cls: type) -> type:
+            return _register(cls, type_key)
+
+        return _decorator_with_name
+
+    def _decorator_default(cls: type) -> type:
+        return _register(cls, cls.__name__)
+
+    if type_key is None:
+        return _decorator_default
+    if isinstance(type_key, type):
+        return _decorator_default(type_key)
+    raise TypeError("type_key must be a string, type, or None")
 
 
 def register_global_func(
     func_name: str | Callable[..., Any],
-    f: Optional[Callable[..., Any]] = None,
+    f: Callable[..., Any] | None = None,
     override: bool = False,
 ) -> Any:
     """Register global function.
@@ -124,12 +136,20 @@ def register_global_func(
         """Register the global function with the FFI core."""
         return core._register_global_func(func_name, myf, override)
 
-    if f:
+    if f is not None:
         return register(f)
     return register
 
 
-def get_global_func(name: str, allow_missing: bool = False) -> Optional[core.Function]:
+@overload
+def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function | None: ...
+
+
+@overload
+def get_global_func(name: str, allow_missing: Literal[False] = False) -> core.Function: ...
+
+
+def get_global_func(name: str, allow_missing: bool = False) -> core.Function | None:
     """Get a global function by name.
 
     Parameters
@@ -179,7 +199,7 @@ def remove_global_func(name: str) -> None:
     get_global_func("ffi.FunctionRemoveGlobal")(name)
 
 
-def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> None:
+def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
     """Initialize register ffi api  functions into a given module.
 
     Parameters
@@ -225,8 +245,8 @@ def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> No
             continue
 
         f = get_global_func(name)
-        f.__name__ = fname
-        setattr(target_module, f.__name__, f)
+        setattr(f, "__name__", fname)
+        setattr(target_module, fname, f)
 
 
 def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
@@ -253,16 +273,16 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
         doc = method.doc if method.doc else None
         method_func = method.func
         if method.is_static:
-            method_pyfunc = staticmethod(method_func)
+            if doc is not None:
+                method_func.__doc__ = doc
+            method_func.__name__ = name
+            method_pyfunc: Any = staticmethod(method_func)
         else:
-            # must call into another method instead of direct capture
-            # to avoid the same method_func variable being used
-            # across multiple loop iterations
-            method_pyfunc = _member_method_wrapper(method_func)
-
-        if doc is not None:
-            method_pyfunc.__doc__ = doc
-        method_pyfunc.__name__ = name
+            wrapped_func = _member_method_wrapper(method_func)
+            if doc is not None:
+                wrapped_func.__doc__ = doc
+            wrapped_func.__name__ = name
+            method_pyfunc = wrapped_func
 
         if hasattr(type_cls, name):
             # skip already defined attributes
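
The `get_global_func` overloads above give callers precise return types: the
`Literal[True]` signature returns an Optional function, the default signature
does not, so most call sites need no None check. The same pattern on a toy
registry (all names are illustrative):

    from __future__ import annotations

    from typing import Callable, Literal, overload

    _TABLE: dict[str, Callable[..., object]] = {"echo": lambda x: x}


    @overload
    def lookup(name: str, allow_missing: Literal[True]) -> Callable[..., object] | None: ...
    @overload
    def lookup(name: str, allow_missing: Literal[False] = False) -> Callable[..., object]: ...
    def lookup(name: str, allow_missing: bool = False) -> Callable[..., object] | None:
        fn = _TABLE.get(name)
        if fn is None and not allow_missing:
            raise KeyError(name)
        return fn


    print(lookup("echo")(42))    # non-Optional: safe to call directly
    print(lookup("nope", True))  # Optional: the caller must handle None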
diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py
index 2bc0d14..6960eda 100644
--- a/python/tvm_ffi/serialization.py
+++ b/python/tvm_ffi/serialization.py
@@ -16,12 +16,14 @@
 # under the License.
 """Serialization related utilities to enable some object can be pickled."""
 
-from typing import Any, Optional
+from __future__ import annotations
+
+from typing import Any
 
 from . import _ffi_api
 
 
-def to_json_graph_str(obj: Any, metadata: Optional[dict] = None) -> str:
+def to_json_graph_str(obj: Any, metadata: dict[str, Any] | None = None) -> str:
     """Dump an object to a JSON graph string.
 
     The JSON graph string is a string representation of of the object
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 3ce9d94..7f2dde5 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -18,7 +18,7 @@
 """Stream context."""
 
 from ctypes import c_void_p
-from typing import Any, NoReturn, Optional, Union
+from typing import Any, Union
 
 from . import core
 from ._tensor import device
@@ -72,7 +72,7 @@ try:
     class TorchStreamContext:
         """Context manager that syncs Torch and FFI stream contexts."""
 
-        def __init__(self, context: Optional[Any]) -> None:
+        def __init__(self, context: Any) -> None:
             """Initialize with an optional Torch stream/graph context 
wrapper."""
             self.torch_context = context
 
@@ -93,14 +93,14 @@ try:
                 self.torch_context.__exit__(*args)
             self.ffi_context.__exit__(*args)
 
-    def use_torch_stream(context: Optional[Any] = None) -> "TorchStreamContext":
+    def use_torch_stream(context: Any = None) -> "TorchStreamContext":
         """Create an FFI stream context with a Torch stream or graph.
 
         cuda graph or current stream if `None` provided.
 
         Parameters
         ----------
-        context : Optional[Any]
+        context : Any = None
             The wrapped torch stream or cuda graph.
 
         Returns
@@ -129,7 +129,7 @@ try:
 
 except ImportError:
 
-    def use_torch_stream(context: Optional[Any] = None) -> NoReturn:
+    def use_torch_stream(context: Any = None) -> "TorchStreamContext":
         """Raise an informative error when Torch is unavailable."""
         raise ImportError("Cannot import torch")
 
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 825f9cf..9769158 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -16,9 +16,12 @@
 # under the License.
 """Testing utilities."""
 
+from __future__ import annotations
+
 from typing import Any, ClassVar
 
 from . import _ffi_api
+from .container import Array, Map
 from .core import Object
 from .dataclasses import c_class, field
 from .registry import register_object
@@ -28,11 +31,18 @@ from .registry import register_object
 class TestObjectBase(Object):
     """Test object base class."""
 
+    v_i64: int
+    v_f64: float
+    v_str: str
+
 
 @register_object("testing.TestIntPair")
 class TestIntPair(Object):
     """Test Int Pair."""
 
+    a: int
+    b: int
+
     def __init__(self, a: int, b: int) -> None:
         """Construct the object."""
         self.__ffi_init__(a, b)
@@ -42,6 +52,9 @@ class TestIntPair(Object):
 class TestObjectDerived(TestObjectBase):
     """Test object derived class."""
 
+    v_map: Map
+    v_array: Array
+
 
 def create_object(type_key: str, **kwargs: Any) -> Object:
     """Make an object by reflection.
@@ -79,7 +92,7 @@ class _TestCxxClassBase:
     not_field_2: ClassVar[int] = 2
 
     def __init__(self, v_i64: int, v_i32: int) -> None:
-        self.__ffi_init__(v_i64 + 1, v_i32 + 2)
+        self.__ffi_init__(v_i64 + 1, v_i32 + 2)  # type: ignore[attr-defined]
 
 
 @c_class("testing.TestCxxClassDerived")
@@ -90,5 +103,5 @@ class _TestCxxClassDerived(_TestCxxClassBase):
 
 @c_class("testing.TestCxxClassDerivedDerived")
 class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
-    v_str: str = field(default_factory=lambda: "default")
-    v_bool: bool
+    v_str: str = field(default_factory=lambda: "default")  # type: 
ignore[assignment]
+    v_bool: bool  # type: ignore[misc]
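
The bare class-level annotations added above (`v_i64: int`, `a: int`, and so
on) declare attributes that the FFI layer injects at runtime, so mypy accepts
attribute access without any real class attribute existing. A plain-Python
sketch of the same declaration style (the class is illustrative):

    class IntPairLike:
        a: int  # declared for the type checker; assigned at construction
        b: int

        def __init__(self, a: int, b: int) -> None:
            self.a = a
            self.b = b


    pair = IntPairLike(1, 2)
    print(pair.a + pair.b)  # mypy knows both attributes are int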
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index 243a319..6efe80f 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -16,10 +16,12 @@
 # under the License.
 """Simple cross-platform advisory file lock utilities."""
 
+from __future__ import annotations
+
 import os
 import sys
 import time
-from typing import Any, Optional
+from typing import Any, Literal
 
 # Platform-specific imports for file locking
 if sys.platform == "win32":
@@ -38,9 +40,9 @@ class FileLock:
     def __init__(self, lock_file_path: str) -> None:
         """Initialize a file lock using the given lock file path."""
         self.lock_file_path = lock_file_path
-        self._file_descriptor = None
+        self._file_descriptor: int | None = None
 
-    def __enter__(self) -> "FileLock":
+    def __enter__(self) -> FileLock:
         """Acquire the lock upon entering the context.
 
         This method blocks until the lock is acquired.
@@ -48,12 +50,12 @@ class FileLock:
         self.blocking_acquire()
         return self
 
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
         """Context manager protocol: release the lock upon exiting the 'with' block."""
         self.release()
         return False  # Propagate exceptions, if any
 
-    def acquire(self) -> Optional[bool]:
+    def acquire(self) -> bool:
         """Acquire an exclusive, non-blocking lock on the file.
 
         Returns True if the lock was acquired, False otherwise.
@@ -79,9 +81,7 @@ class FileLock:
                 self._file_descriptor = None
             raise RuntimeError(f"An unexpected error occurred: {e}")
 
-    def blocking_acquire(
-        self, timeout: Optional[float] = None, poll_interval: float = 0.1
-    ) -> Optional[bool]:
+    def blocking_acquire(self, timeout: float | None = None, poll_interval: float = 0.1) -> bool:
         """Wait until an exclusive lock can be acquired, with an optional 
timeout.
 
         Args:
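
Typing `__exit__` as `Literal[False]` rather than `bool` above tells mypy the
context manager never swallows exceptions, which keeps code after a `with`
block correctly marked reachable or unreachable. A minimal sketch:

    from __future__ import annotations

    from typing import Any, Literal


    class NoopLock:
        def __enter__(self) -> NoopLock:
            return self

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
            return False  # always propagate exceptions


    with NoopLock():
        print("inside the locked region")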
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index da30464..713520f 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -170,7 +170,7 @@ FMT_MAP = {
 }
 
 # Files and patterns to skip during header checking
-SKIP_LIST = []
+SKIP_LIST: list[str] = []
 
 
 def should_skip_file(filepath: str) -> bool:
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 9d08f9a..c776b20 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -183,8 +183,8 @@ def main() -> None:
     cmd = ["git", "ls-files"]
     proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
-    assert proc.returncode == 0, f"{' '.join(cmd)} errored: {out}"
     res = out.decode("utf-8")
+    assert proc.returncode == 0, f"{' '.join(cmd)} errored: {res}"
     flist = res.split()
     error_list = []
 
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index f5f28f3..f42e9b3 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -38,10 +38,10 @@ def test_bad_constructor_init_state() -> None:
     proper repr code
     """
     with pytest.raises(TypeError):
-        tvm_ffi.Array(1)
+        tvm_ffi.Array(1)  # type: ignore[arg-type]
 
     with pytest.raises(AttributeError):
-        tvm_ffi.Map(1)
+        tvm_ffi.Map(1)  # type: ignore[arg-type]
 
 
 def test_array_of_array_map() -> None:
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index a2fa80e..0050e6c 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -57,7 +57,7 @@ def test_cxx_class_derived_derived() -> None:
 
 
 def test_cxx_class_derived_derived_default() -> None:
-    obj = _TestCxxClassDerivedDerived(123, 456, 4, True)
+    obj = _TestCxxClassDerivedDerived(123, 456, 4, True)  # type: ignore[call-arg,misc]
     assert obj.v_i64 == 123
     assert obj.v_i32 == 456
     assert isinstance(obj.v_f64, float) and obj.v_f64 == 4
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 9441c9f..7a3638b 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
+
 import pickle
 
 import pytest
@@ -87,7 +89,7 @@ def test_deive_type_error(dev_type: str, dev_id: int | None) -> None:
 
 def test_deive_id_error() -> None:
     with pytest.raises(TypeError):
-        tvm_ffi.device("cpu", "?")
+        tvm_ffi.device("cpu", "?")  # type: ignore[arg-type]
 
 
 def test_device_pickle() -> None:
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index dda436c..dd94cf3 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -39,9 +39,9 @@ def test_error_from_cxx() -> None:
     try:
         test_raise_error("ValueError", "error XYZ")
     except ValueError as e:
-        assert e.__tvm_ffi_error__.kind == "ValueError"
-        assert e.__tvm_ffi_error__.message == "error XYZ"
-        assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1
+        assert e.__tvm_ffi_error__.kind == "ValueError"  # type: ignore[attr-defined]
+        assert e.__tvm_ffi_error__.message == "error XYZ"  # type: ignore[attr-defined]
+        assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1  # type: ignore[attr-defined]
 
     fapply = tvm_ffi.convert(lambda f, *args: f(*args))
 
@@ -64,17 +64,17 @@ def test_error_from_nested_pyfunc() -> None:
         try:
             fapply(cxx_test_raise_error, "ValueError", "error XYZ")
         except ValueError as e:
-            assert e.__tvm_ffi_error__.kind == "ValueError"
-            assert e.__tvm_ffi_error__.message == "error XYZ"
-            assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1
-            record_object.append(e.__tvm_ffi_error__)
+            assert e.__tvm_ffi_error__.kind == "ValueError"  # type: ignore[attr-defined]
+            assert e.__tvm_ffi_error__.message == "error XYZ"  # type: ignore[attr-defined]
+            assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1  # type: ignore[attr-defined]
+            record_object.append(e.__tvm_ffi_error__)  # type: ignore[attr-defined]
             raise e
 
     try:
         cxx_test_apply(raise_error)
     except ValueError as e:
-        backtrace = e.__tvm_ffi_error__.backtrace
-        assert e.__tvm_ffi_error__.same_as(record_object[0])
+        backtrace = e.__tvm_ffi_error__.backtrace  # type: ignore[attr-defined]
+        assert e.__tvm_ffi_error__.same_as(record_object[0])  # type: ignore[attr-defined]
         assert backtrace.count("TestRaiseError") == 1
         # The following lines may fail if debug symbols are missing
         try:
@@ -108,7 +108,7 @@ def test_error_traceback_update() -> None:
     try:
         raise_cxx_error()
     except ValueError as e:
-        assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1
+        assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1  # type: ignore[attr-defined]
         ffi_error1 = tvm_ffi.convert(e)
         ffi_error2 = fecho(e)
         assert ffi_error1.backtrace.find("raise_cxx_error") != -1
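Note: `__tvm_ffi_error__` is attached to standard exception instances at runtime, so typeshed's `ValueError` declares no such attribute and every access needs `# type: ignore[attr-defined]`. A minimal sketch of the underlying diagnostic, with an illustrative attribute name:

    try:
        raise ValueError("boom")
    except ValueError as e:
        e.extra = "attached at runtime"  # type: ignore[attr-defined]
        print(e.extra)                   # type: ignore[attr-defined]

An alternative would be a cast to a Protocol declaring the attribute; the per-line ignores keep the test bodies otherwise unchanged.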
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index d0afd5f..edc9ffd 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -113,11 +113,11 @@ def test_string_bytes_passing() -> None:
     # small bytes
     assert fecho(b"hello") == b"hello"
     # large bytes
-    x = b"hello" * 100
-    y = fecho(x)
-    assert y == x
-    assert y.__tvm_ffi_object__ is not None
-    fecho(y) == 1
+    x2 = b"hello" * 100
+    y2 = fecho(x2)
+    assert y2 == x2
+    assert y2.__tvm_ffi_object__ is not None
+    fecho(y2) == 1
 
 
 def test_nested_container_passing() -> None:
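Note: the rename to `x2`/`y2` is presumably because `x` and `y` are bound earlier in this test with a different type, and mypy rejects re-assignment with an incompatible type. A minimal sketch:

    x = "hello"    # inferred as str
    x = b"hello"   # mypy: Incompatible types in assignment
                   # (expression has type "bytes", variable has type "str")
    x2 = b"hello"  # fresh name, cleanly inferred as bytes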
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 5454284..cd46bf5 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -15,12 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
+
+from types import ModuleType
 
 import numpy
 import pytest
 
+torch: ModuleType | None
 try:
-    import torch
+    import torch  # type: ignore[no-redef]
 except ImportError:
     torch = None
 
@@ -197,6 +201,7 @@ def test_load_inline_cuda() -> None:
 
 @pytest.mark.skipif(torch is None, reason="Requires torch")
 def test_load_inline_with_env_tensor_allocator() -> None:
+    assert torch is not None
     if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
         pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
     mod: Module = tvm_ffi.cpp.load_inline(
@@ -241,6 +246,7 @@ def test_load_inline_with_env_tensor_allocator() -> None:
     torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
 )
 def test_load_inline_both() -> None:
+    assert torch is not None
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index ea54adf..aa1a791 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -24,6 +24,7 @@ import tvm_ffi
 def test_make_object() -> None:
     # with default values
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase")
+    assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
     assert obj0.v_i64 == 10
     assert obj0.v_f64 == 10.0
     assert obj0.v_str == "hello"
@@ -37,14 +38,16 @@ def test_make_object_via_init() -> None:
 
 def test_method() -> None:
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
-    assert obj0.add_i64(1) == 13
-    assert type(obj0).add_i64.__doc__ == "add_i64 method"
-    assert type(obj0).v_i64.__doc__ == "i64 field"
+    assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
+    assert obj0.add_i64(1) == 13  # type: ignore[attr-defined]
+    assert type(obj0).add_i64.__doc__ == "add_i64 method"  # type: ignore[attr-defined]
+    assert type(obj0).v_i64.__doc__ == "i64 field"  # type: ignore[attr-defined]
 
 
 def test_setter() -> None:
     # test setter
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello")
+    assert isinstance(obj0, tvm_ffi.testing.TestObjectBase)
     assert obj0.v_i64 == 10
     obj0.v_i64 = 11
     assert obj0.v_i64 == 11
@@ -52,10 +55,10 @@ def test_setter() -> None:
     assert obj0.v_str == "world"
 
     with pytest.raises(TypeError):
-        obj0.v_str = 1
+        obj0.v_str = 1  # type: ignore[assignment]
 
     with pytest.raises(TypeError):
-        obj0.v_i64 = "hello"
+        obj0.v_i64 = "hello"  # type: ignore[assignment]
 
 
 def test_derived_object() -> None:
@@ -68,6 +71,7 @@ def test_derived_object() -> None:
     obj0 = tvm_ffi.testing.create_object(
         "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array
     )
+    assert isinstance(obj0, tvm_ffi.testing.TestObjectDerived)
     assert obj0.v_map.same_as(v_map)
     assert obj0.v_array.same_as(v_array)
     assert obj0.v_i64 == 20
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index cfaf650..9280aab 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -15,12 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
+
+from types import ModuleType
+
 import pytest
 import tvm_ffi
 import tvm_ffi.cpp
 
+torch: ModuleType | None
 try:
-    import torch
+    import torch  # type: ignore[no-redef]
 except ImportError:
     torch = None
 
@@ -56,6 +61,7 @@ def test_raw_stream() -> None:
     torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
 )
 def test_torch_stream() -> None:
+    assert torch is not None
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
@@ -78,6 +84,7 @@ def test_torch_stream() -> None:
     torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
 )
 def test_torch_current_stream() -> None:
+    assert torch is not None
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
@@ -103,6 +110,7 @@ def test_torch_current_stream() -> None:
     torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
 )
 def test_torch_graph() -> None:
+    assert torch is not None
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 4c2e9a8..186d91b 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -14,10 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
+from types import ModuleType
+
 import pytest
 
+torch: ModuleType | None
 try:
-    import torch
+    import torch  # type: ignore[no-redef]
 except ImportError:
     torch = None
 
@@ -45,18 +51,19 @@ def test_shape_object() -> None:
     assert shape == (10, 8, 4, 2)
 
     fecho = tvm_ffi.convert(lambda x: x)
-    shape2 = fecho(shape)
+    shape2: tvm_ffi.Shape = fecho(shape)
     assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
     assert isinstance(shape2, tvm_ffi.Shape)
     assert isinstance(shape2, tuple)
 
-    shape3 = tvm_ffi.convert(shape)
+    shape3: tvm_ffi.Shape = tvm_ffi.convert(shape)
     assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__)
     assert isinstance(shape3, tvm_ffi.Shape)
 
 
 @pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not 
enabled")
 def test_tensor_auto_dlpack() -> None:
+    assert torch is not None
     x = torch.arange(128)
     fecho = tvm_ffi.get_global_func("testing.echo")
     y = fecho(x)
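Note: converted FFI callables like `fecho` presumably return `Any`, so annotating the target (`shape2: tvm_ffi.Shape = fecho(shape)`) pins the result to a concrete type that mypy can then verify against the isinstance checks that follow. A minimal sketch with a stand-in echo function:

    from typing import Any

    def fecho(x: Any) -> Any:  # stand-in for tvm_ffi.convert(lambda x: x)
        return x

    shape: tuple[int, ...] = (10, 8, 4, 2)
    shape2: tuple[int, ...] = fecho(shape)  # annotation pins the Any result
    assert isinstance(shape2, tuple)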
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index 2d9b296..4366b58 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -254,9 +254,9 @@ def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
     x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
     y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
     z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
-    x = tvm_ffi.core.DLTensorTestWrapper(x)
-    y = tvm_ffi.core.DLTensorTestWrapper(y)
-    z = tvm_ffi.core.DLTensorTestWrapper(z)
+    x = tvm_ffi.core.DLTensorTestWrapper(x)  # type: ignore[assignment]
+    y = tvm_ffi.core.DLTensorTestWrapper(y)  # type: ignore[assignment]
+    z = tvm_ffi.core.DLTensorTestWrapper(z)  # type: ignore[assignment]
     bench_tvm_ffi_nop_autodlpack(
         f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, 
repeat
     )
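Note: here the benchmark keeps the names `x`/`y`/`z` and re-binds them to a different type, suppressing the diagnostic with `# type: ignore[assignment]` instead of renaming as test_function.py did. Both resolve the same mypy complaint; a two-line sketch:

    x = 1.0
    x = "one"  # type: ignore[assignment]  # deliberate re-bind to a new type

Renaming is cleaner when the old value is no longer needed; the ignore keeps the call sites below unchanged.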
