This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f4e044  chore: Enforce type annotation on public APIs (#21)
8f4e044 is described below

commit 8f4e044a90ff8a3db15742274f2c7df435142b04
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 17 17:03:25 2025 -0700

    chore: Enforce type annotation on public APIs (#21)
    
    ```diff
    [tool.ruff.lint]
    select = [
    ...
    -  # "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
    +  "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
    ...
    ]
    ```
---
 docs/conf.py                                       |  2 +-
 examples/inline_module/main.py                     |  2 +-
 .../packaging/python/my_ffi_extension/__init__.py  |  6 ++-
 examples/packaging/python/my_ffi_extension/base.py |  2 +-
 examples/packaging/run_example.py                  |  4 +-
 examples/quick_start/run_example.py                |  8 +--
 pyproject.toml                                     |  2 +-
 python/tvm_ffi/_dtype.py                           | 14 ++---
 python/tvm_ffi/_optional_torch_c_dlpack.py         |  3 +-
 python/tvm_ffi/_tensor.py                          |  7 +--
 python/tvm_ffi/base.py                             |  2 +-
 python/tvm_ffi/config.py                           |  4 +-
 python/tvm_ffi/container.py                        | 63 ++++++++++++----------
 python/tvm_ffi/cpp/load_inline.py                  |  6 ++-
 python/tvm_ffi/error.py                            | 28 ++++++----
 python/tvm_ffi/libinfo.py                          | 20 +++----
 python/tvm_ffi/module.py                           | 37 ++++++-------
 python/tvm_ffi/registry.py                         | 21 +++++---
 python/tvm_ffi/serialization.py                    |  4 +-
 python/tvm_ffi/stream.py                           | 22 ++++----
 python/tvm_ffi/testing.py                          |  4 +-
 python/tvm_ffi/utils/lockfile.py                   | 15 +++---
 tests/lint/check_asf_header.py                     | 15 +++---
 tests/lint/check_file_type.py                      |  8 +--
 tests/python/test_access_path.py                   | 16 +++---
 tests/python/test_container.py                     | 21 ++++----
 tests/python/test_device.py                        | 23 +++++---
 tests/python/test_dtype.py                         | 10 ++--
 tests/python/test_error.py                         | 16 +++---
 tests/python/test_examples.py                      | 10 ++--
 tests/python/test_function.py                      | 35 ++++++------
 tests/python/test_load_inline.py                   | 14 ++---
 tests/python/test_object.py                        | 13 ++---
 tests/python/test_stream.py                        | 10 ++--
 tests/python/test_string.py                        |  4 +-
 tests/python/test_tensor.py                        |  6 +--
 tests/scripts/benchmark_dlpack.py                  | 53 +++++++++---------
 37 files changed, 294 insertions(+), 236 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 2830711..b4188fe 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -182,7 +182,7 @@ footer_note = (
 )
 
 
-def footer_html():
+def footer_html() -> str:
     """Generate HTML for the documentation footer."""
     # Create footer HTML with two-line layout
     # Generate dropdown menu items
diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py
index 8afa9b5..7231132 100644
--- a/examples/inline_module/main.py
+++ b/examples/inline_module/main.py
@@ -21,7 +21,7 @@ import tvm_ffi.cpp
 from tvm_ffi.module import Module
 
 
-def main():
+def main() -> None:
     """Build, load, and run inline CPU/CUDA functions."""
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py
index 583945b..ae4abfd 100644
--- a/examples/packaging/python/my_ffi_extension/__init__.py
+++ b/examples/packaging/python/my_ffi_extension/__init__.py
@@ -17,11 +17,13 @@
 # isort: skip_file
 """Public Python API for the example tvm-ffi extension package."""
 
+from typing import Any
+
 from .base import _LIB
 from . import _ffi_api
 
 
-def add_one(x, y):
+def add_one(x: Any, y: Any) -> None:
     """Add one to the input tensor.
 
     Parameters
@@ -35,7 +37,7 @@ def add_one(x, y):
     return _LIB.add_one(x, y)
 
 
-def raise_error(msg):
+def raise_error(msg: str) -> None:
     """Raise an error with the given message.
 
     Parameters
diff --git a/examples/packaging/python/my_ffi_extension/base.py b/examples/packaging/python/my_ffi_extension/base.py
index 5b1546f..fb6f6c2 100644
--- a/examples/packaging/python/my_ffi_extension/base.py
+++ b/examples/packaging/python/my_ffi_extension/base.py
@@ -22,7 +22,7 @@ from pathlib import Path
 import tvm_ffi
 
 
-def _load_lib():
+def _load_lib() -> tvm_ffi.Module:
     # first look at the directory of the current file
     file_dir = Path(__file__).resolve().parent
 
diff --git a/examples/packaging/run_example.py b/examples/packaging/run_example.py
index 04650ec..6b5120f 100644
--- a/examples/packaging/run_example.py
+++ b/examples/packaging/run_example.py
@@ -22,7 +22,7 @@ import my_ffi_extension
 import torch
 
 
-def run_add_one():
+def run_add_one() -> None:
     """Invoke add_one from the extension and print the result."""
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
     y = torch.empty_like(x)
@@ -30,7 +30,7 @@ def run_add_one():
     print(y)
 
 
-def run_raise_error():
+def run_raise_error() -> None:
     """Invoke raise_error from the extension to demonstrate error handling."""
     my_ffi_extension.raise_error("This is an error")
 
diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py
index 830c3c7..87c9507 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -27,7 +27,7 @@ except ImportError:
 import numpy
 
 
-def run_add_one_cpu():
+def run_add_one_cpu() -> None:
     """Load the add_one_cpu module and call the add_one_cpu function."""
     mod = tvm_ffi.load_module("build/add_one_cpu.so")
 
@@ -53,7 +53,7 @@ def run_add_one_cpu():
     print(y)
 
 
-def run_add_one_c():
+def run_add_one_c() -> None:
     """Load the add_one_c module and call the add_one_c function."""
     mod = tvm_ffi.load_module("build/add_one_c.so")
 
@@ -73,7 +73,7 @@ def run_add_one_c():
     print(y)
 
 
-def run_add_one_cuda():
+def run_add_one_cuda() -> None:
     """Load the add_one_cuda module and call the add_one_cuda function."""
     if torch is None or not torch.cuda.is_available():
         return
@@ -94,7 +94,7 @@ def run_add_one_cuda():
     print(y)
 
 
-def main():
+def main() -> None:
     """Run the quick start example."""
     run_add_one_cpu()
     run_add_one_c()
diff --git a/pyproject.toml b/pyproject.toml
index 7783c75..77a82d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,7 +132,7 @@ select = [
   "RUF", # ruff, https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
   "NPY", # numpy, https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
   "F",   # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f
-  # "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
+  "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
   "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth
   "D",   # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d
 ]
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index a76d111..ba1735f 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -62,16 +62,16 @@ class dtype(str):
 
     _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
 
-    def __new__(cls, content):
+    def __new__(cls, content: Any) -> "dtype":
         content = str(content)
         val = str.__new__(cls, content)
         val.__tvm_ffi_dtype__ = core.DataType(content)
         return val
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"dtype('{self}')"
 
-    def with_lanes(self, lanes):
+    def with_lanes(self, lanes: int) -> "dtype":
         """Create a new dtype with the given number of lanes.
 
         Parameters
@@ -96,19 +96,19 @@ class dtype(str):
         return val
 
     @property
-    def itemsize(self):
+    def itemsize(self) -> int:
         return self.__tvm_ffi_dtype__.itemsize
 
     @property
-    def type_code(self):
+    def type_code(self) -> int:
         return self.__tvm_ffi_dtype__.type_code
 
     @property
-    def bits(self):
+    def bits(self) -> int:
         return self.__tvm_ffi_dtype__.bits
 
     @property
-    def lanes(self):
+    def lanes(self) -> int:
         return self.__tvm_ffi_dtype__.lanes
 
 
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index b8b1f8f..500b684 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -31,11 +31,12 @@ subsequent calls will be much faster.
 """
 
 import warnings
+from typing import Any, Optional
 
 from . import libinfo
 
 
-def load_torch_c_dlpack_extension():
+def load_torch_c_dlpack_extension() -> Optional[Any]:
     """Load the torch c dlpack extension."""
     cpp_source = """
 #include <dlpack/dlpack.h>
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index bea20a9..903a69d 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -19,6 +19,7 @@
 # if we also want to expose a tensor function in the root namespace
 
 from numbers import Integral
+from typing import Any, Optional, Union
 
 from . import _ffi_api, core, registry
 from .core import Device, DLDeviceType, Tensor, from_dlpack
@@ -35,7 +36,7 @@ class Shape(tuple, core.PyNativeObject):
 
     """
 
-    def __new__(cls, content):
+    def __new__(cls, content: tuple[int, ...]) -> "Shape":
         if any(not isinstance(x, Integral) for x in content):
             raise ValueError("Shape must be a tuple of integers")
         val = tuple.__new__(cls, content)
@@ -43,7 +44,7 @@ class Shape(tuple, core.PyNativeObject):
         return val
 
     # pylint: disable=no-self-argument
-    def __from_tvm_ffi_object__(cls, obj):
+    def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
         """Construct from a given tvm object."""
         content = core._shape_obj_get_py_tuple(obj)
         val = tuple.__new__(cls, content)
@@ -51,7 +52,7 @@ class Shape(tuple, core.PyNativeObject):
         return val
 
 
-def device(device_type, index=None):
+def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = None) -> Device:
     """Construct a TVM FFI device with given device type and index.
 
     Parameters
diff --git a/python/tvm_ffi/base.py b/python/tvm_ffi/base.py
index 8099955..3ec2feb 100644
--- a/python/tvm_ffi/base.py
+++ b/python/tvm_ffi/base.py
@@ -38,7 +38,7 @@ if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 9):
 # ----------------------------
 
 
-def _load_lib():
+def _load_lib() -> ctypes.CDLL:
     """Load libary by searching possible path."""
     lib_path = libinfo.find_libtvm_ffi()
     # The dll search path need to be added explicitly in windows
diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py
index bc31de4..9a418b5 100644
--- a/python/tvm_ffi/config.py
+++ b/python/tvm_ffi/config.py
@@ -23,7 +23,7 @@ from pathlib import Path
 from . import libinfo
 
 
-def find_windows_implib():
+def find_windows_implib() -> str:
     """Find and return the Windows import library path for tvm_ffi.lib."""
     libdir = Path(libinfo.find_libtvm_ffi()).parent
     implib = libdir / "tvm_ffi.lib"
@@ -32,7 +32,7 @@ def find_windows_implib():
     return str(implib)
 
 
-def __main__():  # noqa: PLR0912
+def __main__() -> None:  # noqa: PLR0912
     """Parse CLI args and print build and include configuration paths."""
     parser = argparse.ArgumentParser(
         description="Get various configuration information needed to compile with tvm-ffi"
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 9bb9f97..008dda9 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -17,8 +17,8 @@
 """Container classes."""
 
 import collections.abc
-from collections.abc import Mapping, Sequence
-from typing import Any
+from collections.abc import Iterator, Mapping, Sequence
+from typing import Any, Callable
 
 from . import _ffi_api, core
 from .registry import register_object
@@ -26,15 +26,20 @@ from .registry import register_object
 __all__ = ["Array", "Map"]
 
 
-def getitem_helper(obj, elem_getter, length, idx):
+def getitem_helper(
+    obj: Any,
+    elem_getter: Callable[[Any, int], Any],
+    length: int,
+    idx: int | slice,
+) -> Any:
     """Implement a pythonic __getitem__ helper.
 
     Parameters
     ----------
-    obj: object
+    obj: Any
         The original object
 
-    elem_getter : function
+    elem_getter : Callable[[Any, int], Any]
         A simple function that takes index and return a single element.
 
     length : int
@@ -93,19 +98,19 @@ class Array(core.Object, collections.abc.Sequence):
 
     """
 
-    def __init__(self, input_list: Sequence[Any]):
+    def __init__(self, input_list: Sequence[Any]) -> None:
         """Construct an Array from a Python sequence."""
         self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int | slice) -> Any:
         """Return one element or a Python list for a slice."""
         return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)
 
-    def __len__(self):
+    def __len__(self) -> int:
         """Return the number of elements in the array."""
         return _ffi_api.ArraySize(self)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """Return a string representation of the array."""
         # exception safety handling for chandle=None
         if self.__chandle__() == 0:
@@ -116,13 +121,13 @@ class Array(core.Object, collections.abc.Sequence):
 class KeysView(collections.abc.KeysView):
     """Helper class to return keys view."""
 
-    def __init__(self, backend_map):
+    def __init__(self, backend_map: "Map") -> None:
         self._backend_map = backend_map
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._backend_map)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         if self.__len__() == 0:
             return
         functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
@@ -132,20 +137,20 @@ class KeysView(collections.abc.KeysView):
             if not functor(2):
                 break
 
-    def __contains__(self, k):
+    def __contains__(self, k: Any) -> bool:
         return self._backend_map.__contains__(k)
 
 
 class ValuesView(collections.abc.ValuesView):
     """Helper class to return values view."""
 
-    def __init__(self, backend_map):
+    def __init__(self, backend_map: "Map") -> None:
         self._backend_map = backend_map
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._backend_map)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         if self.__len__() == 0:
             return
         functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
@@ -159,13 +164,13 @@ class ValuesView(collections.abc.ValuesView):
 class ItemsView(collections.abc.ItemsView):
     """Helper class to return items view."""
 
-    def __init__(self, backend_map):
+    def __init__(self, backend_map: "Map") -> None:
         self.backend_map = backend_map
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.backend_map)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[tuple[Any, Any]]:
         if self.__len__() == 0:
             return
         functor = _ffi_api.MapForwardIterFunctor(self.backend_map)
@@ -206,7 +211,7 @@ class Map(core.Object, collections.abc.Mapping):
 
     """
 
-    def __init__(self, input_dict: Mapping[Any, Any]):
+    def __init__(self, input_dict: Mapping[Any, Any]) -> None:
         """Construct a Map from a Python mapping."""
         list_kvs = []
         for k, v in input_dict.items():
@@ -214,35 +219,35 @@ class Map(core.Object, collections.abc.Mapping):
             list_kvs.append(v)
         self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)
 
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> Any:
         """Return the value for key `k` or raise KeyError."""
         return _ffi_api.MapGetItem(self, k)
 
-    def __contains__(self, k):
+    def __contains__(self, k: Any) -> bool:
         """Return True if the map contains key `k`."""
         return _ffi_api.MapCount(self, k) != 0
 
-    def keys(self):
+    def keys(self) -> KeysView:
         """Return a dynamic view of the map's keys."""
         return KeysView(self)
 
-    def values(self):
+    def values(self) -> "ValuesView":
         """Return a dynamic view of the map's values."""
         return ValuesView(self)
 
-    def items(self):
+    def items(self) -> ItemsView:
         """Get the items from the map."""
         return ItemsView(self)
 
-    def __len__(self):
+    def __len__(self) -> int:
         """Return the number of items in the map."""
         return _ffi_api.MapSize(self)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         """Iterate over the map's keys."""
         return iter(self.keys())
 
-    def get(self, key, default=None):
+    def get(self, key: Any, default: Any | None = None) -> Any:
         """Get an element with a default value.
 
         Parameters
@@ -261,7 +266,7 @@ class Map(core.Object, collections.abc.Mapping):
         """
         return self[key] if key in self else default
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """Return a string representation of the map."""
         # exception safety handling for chandle=None
         if self.__chandle__() == 0:
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index 264a7bb..2c1caf5 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -130,7 +130,11 @@ def _get_cuda_target() -> str:
             return "-gencode=arch=compute_70,code=sm_70"
 
 
-def _run_command_in_dev_prompt(args, cwd, capture_output):
+def _run_command_in_dev_prompt(
+    args: list[str],
+    cwd: str | os.PathLike[str],
+    capture_output: bool,
+) -> subprocess.CompletedProcess:
     """Locates the Developer Command Prompt and runs a command within its environment."""
     try:
         # Path to vswhere.exe
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index 28788ef..66f068f 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -21,11 +21,12 @@ import ast
 import re
 import sys
 import types
+from typing import Any, Optional
 
 from . import core
 
 
-def _parse_traceback(traceback):
+def _parse_traceback(traceback: str) -> list[tuple[str, int, str]]:
     """Parse the traceback string into a list of (filename, lineno, func).
 
     Parameters
@@ -57,11 +58,11 @@ def _parse_traceback(traceback):
 class TracebackManager:
     """Helper to manage traceback generation."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         """Initialize the traceback manager and its cache."""
         self._code_cache = {}
 
-    def _get_cached_code_object(self, filename, lineno, func):
+    def _get_cached_code_object(self, filename: str, lineno: int, func: str) -> types.CodeType:
         # Hack to create a code object that points to the correct
         # line number and function name
         key = (filename, lineno, func)
@@ -83,7 +84,7 @@ class TracebackManager:
         self._code_cache[key] = code_object
         return code_object
 
-    def _create_frame(self, filename, lineno, func):
+    def _create_frame(self, filename: str, lineno: int, func: str) -> types.FrameType:
         """Create a frame object from the filename, lineno, and func."""
         code_object = self._get_cached_code_object(filename, lineno, func)
         # call into get frame, but changes the context so the code
@@ -92,7 +93,13 @@ class TracebackManager:
         # pylint: disable=eval-used
         return eval(code_object, context, context)
 
-    def append_traceback(self, tb, filename, lineno, func):
+    def append_traceback(
+        self,
+        tb: Optional[types.TracebackType],
+        filename: str,
+        lineno: int,
+        func: str,
+    ) -> types.TracebackType:
         """Append a traceback to the given traceback.
 
         Parameters
@@ -119,7 +126,7 @@ class TracebackManager:
 _TRACEBACK_MANAGER = TracebackManager()
 
 
-def _with_append_traceback(py_error, traceback):
+def _with_append_traceback(py_error: BaseException, traceback: str) -> BaseException:
     """Append the traceback to the py_error and return it."""
     tb = py_error.__traceback__
     for filename, lineno, func in reversed(_parse_traceback(traceback)):
@@ -127,7 +134,7 @@ def _with_append_traceback(py_error, traceback):
     return py_error.with_traceback(tb)
 
 
-def _traceback_to_str(tb):
+def _traceback_to_str(tb: Optional[types.TracebackType]) -> str:
     """Convert the traceback to a string."""
     lines = []
     while tb is not None:
@@ -144,7 +151,10 @@ core._WITH_APPEND_TRACEBACK = _with_append_traceback
 core._TRACEBACK_TO_STR = _traceback_to_str
 
 
-def register_error(name_or_cls=None, cls=None):
+def register_error(
+    name_or_cls: str | type | None = None,
+    cls: Optional[type] = None,
+) -> Any:
     """Register an error class so it can be recognized by the ffi error handler.
 
     Parameters
@@ -176,7 +186,7 @@ def register_error(name_or_cls=None, cls=None):
         cls = name_or_cls
         name_or_cls = cls.__name__
 
-    def register(mycls):
+    def register(mycls: type) -> type:
         """Register the error class name with the FFI core."""
         err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__
         core.ERROR_NAME_TO_TYPE[err_name] = mycls
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index b707f2b..382690b 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -21,7 +21,7 @@ import sys
 from pathlib import Path
 
 
-def split_env_var(env_var, split):
+def split_env_var(env_var: str, split: str) -> list[str]:
     """Split an environment variable string.
 
     Parameters
@@ -43,7 +43,7 @@ def split_env_var(env_var, split):
     return []
 
 
-def get_dll_directories():
+def get_dll_directories() -> list[str]:
     """Get the possible dll directories."""
     ffi_dir = Path(__file__).expanduser().resolve().parent
     dll_path = [ffi_dir / "lib"]
@@ -62,7 +62,7 @@ def get_dll_directories():
     return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()]
 
 
-def find_libtvm_ffi():
+def find_libtvm_ffi() -> str:
     """Find libtvm_ffi."""
     dll_path = get_dll_directories()
     if sys.platform.startswith("win32"):
@@ -82,7 +82,7 @@ def find_libtvm_ffi():
     return lib_found[0]
 
 
-def find_source_path():
+def find_source_path() -> str:
     """Find packaged source home path."""
     candidates = [
         str(Path(__file__).resolve().parent),
@@ -94,7 +94,7 @@ def find_source_path():
     raise RuntimeError("Cannot find home path.")
 
 
-def find_cmake_path():
+def find_cmake_path() -> str:
     """Find the preferred cmake path."""
     candidates = [
         str(Path(__file__).resolve().parent / "cmake"),
@@ -106,7 +106,7 @@ def find_cmake_path():
     raise RuntimeError("Cannot find cmake path.")
 
 
-def find_include_path():
+def find_include_path() -> str:
     """Find header files for C compilation."""
     candidates = [
         str(Path(__file__).resolve().parent / "include"),
@@ -118,7 +118,7 @@ def find_include_path():
     raise RuntimeError("Cannot find include path.")
 
 
-def find_python_helper_include_path():
+def find_python_helper_include_path() -> str:
     """Find header files for C compilation."""
     candidates = [
         str(Path(__file__).resolve().parent / "include"),
@@ -130,7 +130,7 @@ def find_python_helper_include_path():
     raise RuntimeError("Cannot find python helper include path.")
 
 
-def find_dlpack_include_path():
+def find_dlpack_include_path() -> str:
     """Find dlpack header files for C compilation."""
     install_include_path = Path(__file__).resolve().parent / "include"
     if (install_include_path / "dlpack").is_dir():
@@ -145,7 +145,7 @@ def find_dlpack_include_path():
     raise RuntimeError("Cannot find include path.")
 
 
-def find_cython_lib():
+def find_cython_lib() -> str:
     """Find the path to tvm cython."""
     path_candidates = [
         Path(__file__).resolve().parent,
@@ -158,7 +158,7 @@ def find_cython_lib():
     raise RuntimeError("Cannot find tvm cython path.")
 
 
-def include_paths():
+def include_paths() -> list[str]:
     """Find all include paths needed for FFI related compilation."""
     include_path = find_include_path()
     python_helper_include_path = find_python_helper_include_path()
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 335e262..acdc11e 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name
 
 from enum import IntEnum
+from typing import Any
 
 from . import _ffi_api, core
 from .registry import register_object
@@ -58,12 +59,12 @@ class Module(core.Object):
     entry_name = "main"
 
     @property
-    def kind(self):
+    def kind(self) -> str:
         """Get type key of the module."""
         return _ffi_api.ModuleGetKind(self)
 
     @property
-    def imports(self):
+    def imports(self) -> list["Module"]:
         """Get imported modules.
 
         Returns
@@ -74,7 +75,7 @@ class Module(core.Object):
         """
         return self.imports_
 
-    def implements_function(self, name, query_imports=False):
+    def implements_function(self, name: str, query_imports: bool = False) -> bool:
         """Return True if the module defines a global function.
 
         Note
@@ -101,7 +102,7 @@ class Module(core.Object):
         """
         return _ffi_api.ModuleImplementsFunction(self, name, query_imports)
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> core.Function:
         """Accessor to allow getting functions as attributes."""
         try:
             func = self.get_function(name)
@@ -110,7 +111,7 @@ class Module(core.Object):
         except AttributeError:
             raise AttributeError(f"Module has no function '{name}'")
 
-    def get_function(self, name, query_imports=False):
+    def get_function(self, name: str, query_imports: bool = False) -> core.Function:
         """Get function from the module.
 
         Parameters
@@ -132,7 +133,7 @@ class Module(core.Object):
             raise AttributeError(f"Module has no function '{name}'")
         return func
 
-    def import_module(self, module):
+    def import_module(self, module: "Module") -> None:
         """Add module to the import list of current one.
 
         Parameters
@@ -143,18 +144,18 @@ class Module(core.Object):
         """
         _ffi_api.ModuleImportModule(self, module)
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> core.Function:
         """Return function by name using item access (module["func"])."""
         if not isinstance(name, str):
             raise ValueError("Can only take string as function name")
         return self.get_function(name)
 
-    def __call__(self, *args):
+    def __call__(self, *args: Any) -> Any:
         """Call the module's entry function (`main`)."""
         # pylint: disable=not-callable
         return self.main(*args)
 
-    def inspect_source(self, fmt=""):
+    def inspect_source(self, fmt: str = "") -> str:
         """Get source code from module, if available.
 
         Parameters
@@ -170,11 +171,11 @@ class Module(core.Object):
         """
         return _ffi_api.ModuleInspectSource(self, fmt)
 
-    def get_write_formats(self):
+    def get_write_formats(self) -> list[str]:
         """Get the format of the module."""
         return _ffi_api.ModuleGetWriteFormats(self)
 
-    def get_property_mask(self):
+    def get_property_mask(self) -> int:
         """Get the runtime module property mask. The mapping is stated in ModulePropertyMask.
 
         Returns
@@ -185,7 +186,7 @@ class Module(core.Object):
         """
         return _ffi_api.ModuleGetPropertyMask(self)
 
-    def is_binary_serializable(self):
+    def is_binary_serializable(self) -> bool:
         """Return whether the module is binary serializable (supports save_to_bytes).
 
         Returns
@@ -196,7 +197,7 @@ class Module(core.Object):
         """
         return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0
 
-    def is_runnable(self):
+    def is_runnable(self) -> bool:
         """Return whether the module is runnable (supports get_function).
 
         Returns
@@ -207,7 +208,7 @@ class Module(core.Object):
         """
         return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0
 
-    def is_compilation_exportable(self):
+    def is_compilation_exportable(self) -> bool:
         """Return whether the module is compilation exportable.
 
         write_to_file is supported for object or source.
@@ -220,11 +221,11 @@ class Module(core.Object):
         """
         return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0
 
-    def clear_imports(self):
+    def clear_imports(self) -> None:
         """Remove all imports of the module."""
         _ffi_api.ModuleClearImports(self)
 
-    def write_to_file(self, file_name, fmt=""):
+    def write_to_file(self, file_name: str, fmt: str = "") -> None:
         """Write the current module to file.
 
         Parameters
@@ -242,7 +243,7 @@ class Module(core.Object):
         _ffi_api.ModuleWriteToFile(self, file_name, fmt)
 
 
-def system_lib(symbol_prefix=""):
+def system_lib(symbol_prefix: str = "") -> Module:
     """Get system-wide library module singleton.
 
     System lib is a global module that contains self register functions in startup.
@@ -267,7 +268,7 @@ def system_lib(symbol_prefix=""):
     return _ffi_api.SystemLib(symbol_prefix)
 
 
-def load_module(path):
+def load_module(path: str) -> Module:
     """Load module from file.
 
     Parameters
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 60c8ded..1f4e340 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -17,6 +17,7 @@
 """FFI registry to register function and objects."""
 
 import sys
+from typing import Any, Callable, Optional
 
 from . import core
 
@@ -24,7 +25,7 @@ from . import core
 _SKIP_UNKNOWN_OBJECTS = False
 
 
-def register_object(type_key=None):
+def register_object(type_key: str | type | None = None) -> Any:
     """Register object type.
 
     Parameters
@@ -46,7 +47,7 @@ def register_object(type_key=None):
     """
     object_name = type_key if isinstance(type_key, str) else type_key.__name__
 
-    def register(cls):
+    def register(cls: type) -> type:
         """Register the object type with the FFI core."""
         type_index = core._object_type_key_to_index(object_name)
         if type_index is None:
@@ -63,7 +64,11 @@ def register_object(type_key=None):
     return register(type_key)
 
 
-def register_global_func(func_name, f=None, override=False):
+def register_global_func(
+    func_name: str | Callable[..., Any],
+    f: Optional[Callable[..., Any]] = None,
+    override: bool = False,
+) -> Any:
     """Register global function.
 
     Parameters
@@ -114,7 +119,7 @@ def register_global_func(func_name, f=None, override=False):
     if not isinstance(func_name, str):
         raise ValueError("expect string function name")
 
-    def register(myf):
+    def register(myf: Callable[..., Any]) -> Any:
         """Register the global function with the FFI core."""
         return core._register_global_func(func_name, myf, override)
 
@@ -123,7 +128,7 @@ def register_global_func(func_name, f=None, override=False):
     return register
 
 
-def get_global_func(name, allow_missing=False):
+def get_global_func(name: str, allow_missing: bool = False) -> 
Optional[core.Function]:
     """Get a global function by name.
 
     Parameters
@@ -147,7 +152,7 @@ def get_global_func(name, allow_missing=False):
     return core._get_global_func(name, allow_missing)
 
 
-def list_global_func_names():
+def list_global_func_names() -> list[str]:
     """Get list of global functions registered.
 
     Returns
@@ -161,7 +166,7 @@ def list_global_func_names():
     return [name_functor(i) for i in range(num_names)]
 
 
-def remove_global_func(name):
+def remove_global_func(name: str) -> None:
     """Remove a global function by name.
 
     Parameters
@@ -173,7 +178,7 @@ def remove_global_func(name):
     get_global_func("ffi.FunctionRemoveGlobal")(name)
 
 
-def init_ffi_api(namespace, target_module_name=None):
+def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> 
None:
     """Initialize register ffi api  functions into a given module.
 
     Parameters
diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py
index 803d533..2bc0d14 100644
--- a/python/tvm_ffi/serialization.py
+++ b/python/tvm_ffi/serialization.py
@@ -21,7 +21,7 @@ from typing import Any, Optional
 from . import _ffi_api
 
 
-def to_json_graph_str(obj: Any, metadata: Optional[dict] = None):
+def to_json_graph_str(obj: Any, metadata: Optional[dict] = None) -> str:
     """Dump an object to a JSON graph string.
 
     The JSON graph string is a string representation of of the object
@@ -45,7 +45,7 @@ def to_json_graph_str(obj: Any, metadata: Optional[dict] = 
None):
     return _ffi_api.ToJSONGraphString(obj, metadata)
 
 
-def from_json_graph_str(json_str: str):
+def from_json_graph_str(json_str: str) -> Any:
     """Load an object from a JSON graph string.
 
     The JSON graph string is a string representation of of the object
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 81cbabe..3ce9d94 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -18,7 +18,7 @@
 """Stream context."""
 
 from ctypes import c_void_p
-from typing import Any, Optional, Union
+from typing import Any, NoReturn, Optional, Union
 
 from . import core
 from ._tensor import device
@@ -46,19 +46,20 @@ class StreamContext:
 
     """
 
-    def __init__(self, device: core.Device, stream: Union[int, c_void_p]):
+    def __init__(self, device: core.Device, stream: Union[int, c_void_p]) -> 
None:
         """Initialize a stream context with a device and stream handle."""
         self.device_type = device.dlpack_device_type()
         self.device_id = device.index
         self.stream = stream
 
-    def __enter__(self):
+    def __enter__(self) -> "StreamContext":
         """Enter the context and set the current stream."""
         self.prev_stream = core._env_set_current_stream(
             self.device_type, self.device_id, self.stream
         )
+        return self
 
-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         """Exit the context and restore the previous stream."""
         self.prev_stream = core._env_set_current_stream(
             self.device_type, self.device_id, self.prev_stream
@@ -71,11 +72,11 @@ try:
     class TorchStreamContext:
         """Context manager that syncs Torch and FFI stream contexts."""
 
-        def __init__(self, context: Optional[Any]):
+        def __init__(self, context: Optional[Any]) -> None:
             """Initialize with an optional Torch stream/graph context 
wrapper."""
             self.torch_context = context
 
-        def __enter__(self):
+        def __enter__(self) -> "TorchStreamContext":
             """Enter both Torch and FFI stream contexts."""
             if self.torch_context:
                 self.torch_context.__enter__()
@@ -84,14 +85,15 @@ try:
                 device(str(current_stream.device)), current_stream.cuda_stream
             )
             self.ffi_context.__enter__()
+            return self
 
-        def __exit__(self, *args):
+        def __exit__(self, *args: Any) -> None:
             """Exit both Torch and FFI stream contexts."""
             if self.torch_context:
                 self.torch_context.__exit__(*args)
             self.ffi_context.__exit__(*args)
 
-    def use_torch_stream(context: Optional[Any] = None):
+    def use_torch_stream(context: Optional[Any] = None) -> 
"TorchStreamContext":
         """Create an FFI stream context with a Torch stream or graph.
 
         cuda graph or current stream if `None` provided.
@@ -127,12 +129,12 @@ try:
 
 except ImportError:
 
-    def use_torch_stream(context: Optional[Any] = None):
+    def use_torch_stream(context: Optional[Any] = None) -> NoReturn:
         """Raise an informative error when Torch is unavailable."""
         raise ImportError("Cannot import torch")
 
 
-def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]):
+def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> 
StreamContext:
     """Create a ffi stream context with given device and stream handle.
 
     Parameters
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 3c173dc..e58c115 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -16,6 +16,8 @@
 # under the License.
 """Testing utilities."""
 
+from typing import Any
+
 from . import _ffi_api
 from .core import Object
 from .registry import register_object
@@ -31,7 +33,7 @@ class TestObjectDerived(TestObjectBase):
     """Test object derived class."""
 
 
-def create_object(type_key: str, **kwargs) -> Object:
+def create_object(type_key: str, **kwargs: Any) -> Object:
     """Make an object by reflection.
 
     Parameters
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index b317f04..243a319 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -19,6 +19,7 @@
 import os
 import sys
 import time
+from typing import Any, Optional
 
 # Platform-specific imports for file locking
 if sys.platform == "win32":
@@ -34,12 +35,12 @@ class FileLock:
     cooperating processes.
     """
 
-    def __init__(self, lock_file_path):
+    def __init__(self, lock_file_path: str) -> None:
         """Initialize a file lock using the given lock file path."""
         self.lock_file_path = lock_file_path
         self._file_descriptor = None
 
-    def __enter__(self):
+    def __enter__(self) -> "FileLock":
         """Acquire the lock upon entering the context.
 
         This method blocks until the lock is acquired.
@@ -47,12 +48,12 @@ class FileLock:
         self.blocking_acquire()
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
         """Context manager protocol: release the lock upon exiting the 'with' 
block."""
         self.release()
         return False  # Propagate exceptions, if any
 
-    def acquire(self):
+    def acquire(self) -> Optional[bool]:
         """Acquire an exclusive, non-blocking lock on the file.
 
         Returns True if the lock was acquired, False otherwise.
@@ -78,7 +79,9 @@ class FileLock:
                 self._file_descriptor = None
             raise RuntimeError(f"An unexpected error occurred: {e}")
 
-    def blocking_acquire(self, timeout=None, poll_interval=0.1):
+    def blocking_acquire(
+        self, timeout: Optional[float] = None, poll_interval: float = 0.1
+    ) -> Optional[bool]:
         """Wait until an exclusive lock can be acquired, with an optional 
timeout.
 
         Args:
@@ -100,7 +103,7 @@ class FileLock:
 
             time.sleep(poll_interval)
 
-    def release(self):
+    def release(self) -> None:
         """Releases the lock and closes the file descriptor."""
         if self._file_descriptor is not None:
             if sys.platform == "win32":
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index 9fcce8b..5ad7571 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -21,6 +21,7 @@ import fnmatch
 import subprocess
 import sys
 from pathlib import Path
+from typing import Optional
 
 header_cstyle = """
 /*
@@ -171,7 +172,7 @@ FMT_MAP = {
 SKIP_LIST = []
 
 
-def should_skip_file(filepath):
+def should_skip_file(filepath: str) -> bool:
     """Check if file should be skipped based on SKIP_LIST."""
     for pattern in SKIP_LIST:
         if fnmatch.fnmatch(filepath, pattern):
@@ -179,7 +180,7 @@ def should_skip_file(filepath):
     return False
 
 
-def get_git_files():
+def get_git_files() -> Optional[list[str]]:
     """Get list of files tracked by git."""
     try:
         result = subprocess.run(
@@ -196,7 +197,7 @@ def get_git_files():
         return None
 
 
-def copyright_line(line):
+def copyright_line(line: str) -> bool:
     # Following two items are intentionally break apart
     # so that the copyright detector won"t detect the file itself.
     if line.find("Copyright " + "(c)") != -1:
@@ -208,7 +209,7 @@ def copyright_line(line):
     return False
 
 
-def check_header(fname, header):
+def check_header(fname: str, header: str) -> bool:
     """Check header status of file without modifying it."""
     if not Path(fname).exists():
         print(f"ERROR: Cannot find {fname}")
@@ -240,7 +241,7 @@ def check_header(fname, header):
     return True
 
 
-def collect_files():
+def collect_files() -> Optional[list[str]]:
     """Collect all files that need header checking from git."""
     files = []
 
@@ -271,7 +272,7 @@ def collect_files():
     return files
 
 
-def add_header(fname, header):  # noqa: PLR0912
+def add_header(fname: str, header: str) -> None:  # noqa: PLR0912
     """Add header to file."""
     if not Path(fname).exists():
         print(f"Cannot find {fname} ...")
@@ -320,7 +321,7 @@ def add_header(fname, header):  # noqa: PLR0912
         print(f"Removed copyright line from {fname}")
 
 
-def main():  # noqa: PLR0911, PLR0912
+def main() -> int:  # noqa: PLR0911, PLR0912
     parser = argparse.ArgumentParser(
         description="Check and fix ASF headers in source files tracked by git",
         formatter_class=argparse.RawDescriptionHelpFormatter,
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 9517168..e9fb40f 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -116,7 +116,7 @@ ALLOW_FILE_NAME = {
 ALLOW_SPECIFIC_FILE = {"LICENSE", "NOTICE", "KEYS", "DISCLAIMER"}
 
 
-def filename_allowed(name):
+def filename_allowed(name: str) -> bool:
     """Check if name is allowed by the current policy.
 
     Paramaters
@@ -146,7 +146,7 @@ def filename_allowed(name):
     return False
 
 
-def copyright_line(line):
+def copyright_line(line: str) -> bool:
     # Following two items are intentionally break apart
     # so that the copyright detector won't detect the file itself.
     if line.find("Copyright " + "(c)") != -1:
@@ -158,7 +158,7 @@ def copyright_line(line):
     return False
 
 
-def check_asf_copyright(fname):
+def check_asf_copyright(fname: str) -> bool:
     if fname.endswith(".png"):
         return True
     if not Path(fname).is_file():
@@ -178,7 +178,7 @@ def check_asf_copyright(fname):
     return True
 
 
-def main():
+def main() -> None:
     cmd = ["git", "ls-files"]
     proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
diff --git a/tests/python/test_access_path.py b/tests/python/test_access_path.py
index d3f59fb..6f0af34 100644
--- a/tests/python/test_access_path.py
+++ b/tests/python/test_access_path.py
@@ -19,7 +19,7 @@
 from tvm_ffi.access_path import AccessKind, AccessPath
 
 
-def test_root_path():
+def test_root_path() -> None:
     root = AccessPath.root()
     assert isinstance(root, AccessPath)
     steps = root.to_steps()
@@ -27,7 +27,7 @@ def test_root_path():
     assert root == AccessPath.root()
 
 
-def test_path_attr():
+def test_path_attr() -> None:
     path = AccessPath.root().attr("foo")
     assert isinstance(path, AccessPath)
     steps = path.to_steps()
@@ -37,7 +37,7 @@ def test_path_attr():
     assert path.parent == AccessPath.root()
 
 
-def test_path_array_item():
+def test_path_array_item() -> None:
     path = AccessPath.root().array_item(2)
     assert isinstance(path, AccessPath)
     steps = path.to_steps()
@@ -47,7 +47,7 @@ def test_path_array_item():
     assert path.parent == AccessPath.root()
 
 
-def test_path_missing_array_element():
+def test_path_missing_array_element() -> None:
     path = AccessPath.root().array_item_missing(2)
     assert isinstance(path, AccessPath)
     steps = path.to_steps()
@@ -57,7 +57,7 @@ def test_path_missing_array_element():
     assert path.parent == AccessPath.root()
 
 
-def test_path_map_item():
+def test_path_map_item() -> None:
     path = AccessPath.root().map_item("foo")
     assert isinstance(path, AccessPath)
     steps = path.to_steps()
@@ -67,7 +67,7 @@ def test_path_map_item():
     assert path.parent == AccessPath.root()
 
 
-def test_path_missing_map_item():
+def test_path_missing_map_item() -> None:
     path = AccessPath.root().map_item_missing("foo")
     assert isinstance(path, AccessPath)
     steps = path.to_steps()
@@ -77,7 +77,7 @@ def test_path_missing_map_item():
     assert path.parent == AccessPath.root()
 
 
-def test_path_is_prefix_of():
+def test_path_is_prefix_of() -> None:
     # Root is prefix of root
     assert AccessPath.root().is_prefix_of(AccessPath.root())
 
@@ -107,7 +107,7 @@ def test_path_is_prefix_of():
     )
 
 
-def test_path_equal():
+def test_path_equal() -> None:
     # Root equals root
     assert AccessPath.root() == AccessPath.root()
 
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index 9300bc7..f5f28f3 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 import pickle
+from typing import Any
 
 import pytest
 import tvm_ffi
 
 
-def test_array():
+def test_array() -> None:
     a = tvm_ffi.convert([1, 2, 3])
     assert isinstance(a, tvm_ffi.Array)
     assert len(a) == 3
@@ -29,7 +30,7 @@ def test_array():
     assert (a_slice[0], a_slice[1]) == (1, 2)
 
 
-def test_bad_constructor_init_state():
+def test_bad_constructor_init_state() -> None:
     """Test when error is raised before __init_handle_by_constructor.
 
     This case we need the FFI binding to gracefully handle both repr
@@ -43,7 +44,7 @@ def test_bad_constructor_init_state():
         tvm_ffi.Map(1)
 
 
-def test_array_of_array_map():
+def test_array_of_array_map() -> None:
     a = tvm_ffi.convert([[1, 2, 3], {"A": 5, "B": 6}])
     assert isinstance(a, tvm_ffi.Array)
     assert len(a) == 2
@@ -54,7 +55,7 @@ def test_array_of_array_map():
     assert a[1]["B"] == 6
 
 
-def test_int_map():
+def test_int_map() -> None:
     amap = tvm_ffi.convert({3: 2, 4: 3})
     assert 3 in amap
     assert len(amap) == 2
@@ -67,9 +68,9 @@ def test_int_map():
     assert tuple(amap.values()) == (2, 3)
 
 
-def test_array_map_of_opaque_object():
+def test_array_map_of_opaque_object() -> None:
     class MyObject:
-        def __init__(self, value):
+        def __init__(self, value: Any) -> None:
             self.value = value
 
     a = tvm_ffi.convert([MyObject("hello"), MyObject(1)])
@@ -89,7 +90,7 @@ def test_array_map_of_opaque_object():
     assert y["b"].value == "hello"
 
 
-def test_str_map():
+def test_str_map() -> None:
     data = []
     for i in reversed(range(10)):
         data.append((f"a{i}", i))
@@ -103,13 +104,13 @@ def test_str_map():
     assert tuple(k for k in amap) == tuple(k for k, _ in data)
 
 
-def test_key_not_found():
+def test_key_not_found() -> None:
     amap = tvm_ffi.convert({3: 2, 4: 3})
     with pytest.raises(KeyError):
         amap[5]
 
 
-def test_repr():
+def test_repr() -> None:
     a = tvm_ffi.convert([1, 2, 3])
     assert str(a) == "[1, 2, 3]"
     amap = tvm_ffi.convert({3: 2, 4: 3})
@@ -119,7 +120,7 @@ def test_repr():
     assert str(smap) == "{'a': 1, 'b': 2}"
 
 
-def test_serialization():
+def test_serialization() -> None:
     a = tvm_ffi.convert([1, 2, 3])
     b = pickle.loads(pickle.dumps(a))
     assert str(b) == "[1, 2, 3]"
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 1ee735c..30c964a 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -22,7 +22,7 @@ import tvm_ffi
 from tvm_ffi import DLDeviceType
 
 
-def test_device():
+def test_device() -> None:
     device = tvm_ffi.Device("cuda", 0)
     assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA
     assert device.index == 0
@@ -30,7 +30,7 @@ def test_device():
     assert device.__repr__() == "device(type='cuda', index=0)"
 
 
-def test_device_from_str():
+def test_device_from_str() -> None:
     device = tvm_ffi.device("ext_dev:0")
     assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLExtDev
     assert device.index == 0
@@ -48,7 +48,11 @@ def test_device_from_str():
         ("metal:2", DLDeviceType.kDLMetal, 2),
     ],
 )
-def test_device_dlpack_device_type(dev_str, expected_device_type, 
expect_device_id):
+def test_device_dlpack_device_type(
+    dev_str: str,
+    expected_device_type: DLDeviceType,
+    expect_device_id: int,
+) -> None:
     dev = tvm_ffi.device(dev_str)
     assert dev.dlpack_device_type() == expected_device_type
     assert dev.index == expect_device_id
@@ -64,24 +68,29 @@ def test_device_dlpack_device_type(dev_str, 
expected_device_type, expect_device_
         (DLDeviceType.kDLMetal, 2, DLDeviceType.kDLMetal, 2),
     ],
 )
-def test_device_with_dev_id(dev_type, dev_id, expected_device_type, 
expect_device_id):
+def test_device_with_dev_id(
+    dev_type: str | DLDeviceType,
+    dev_id: int,
+    expected_device_type: DLDeviceType,
+    expect_device_id: int,
+) -> None:
     dev = tvm_ffi.device(dev_type, dev_id)
     assert dev.dlpack_device_type() == expected_device_type
     assert dev.index == expect_device_id
 
 
 @pytest.mark.parametrize("dev_type, dev_id", [("cpu:0:0", None), ("cpu:?", 
None), ("cpu:", None)])
-def test_deive_type_error(dev_type, dev_id):
+def test_deive_type_error(dev_type: str, dev_id: int | None) -> None:
     with pytest.raises(ValueError):
         tvm_ffi.device(dev_type, dev_id)
 
 
-def test_deive_id_error():
+def test_deive_id_error() -> None:
     with pytest.raises(TypeError):
         tvm_ffi.device("cpu", "?")
 
 
-def test_device_pickle():
+def test_device_pickle() -> None:
     device = tvm_ffi.device("cuda", 0)
     device_pickled = pickle.loads(pickle.dumps(device))
     assert device_pickled.dlpack_device_type() == device.dlpack_device_type()
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 9230ccc..61a0b7b 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -22,7 +22,7 @@ import pytest
 import tvm_ffi
 
 
-def test_dtype():
+def test_dtype() -> None:
     float32 = tvm_ffi.dtype("float32")
     assert float32.__repr__() == "dtype('float32')"
     assert type(float32) == tvm_ffi.dtype
@@ -42,13 +42,13 @@ def test_dtype():
         ("bool", 1),
     ],
 )
-def test_dtype_itemsize(dtype_str, expected_size):
+def test_dtype_itemsize(dtype_str: str, expected_size: int) -> None:
     dtype = tvm_ffi.dtype(dtype_str)
     assert dtype.itemsize == expected_size
 
 
 @pytest.mark.parametrize("dtype_str", ["int32xvscalex4"])
-def test_dtype_itemmize_error(dtype_str):
+def test_dtype_itemmize_error(dtype_str: str) -> None:
     with pytest.raises(ValueError):
         tvm_ffi.dtype(dtype_str).itemsize
 
@@ -65,7 +65,7 @@ def test_dtype_itemmize_error(dtype_str):
         "bool",
     ],
 )
-def test_dtype_pickle(dtype_str):
+def test_dtype_pickle(dtype_str: str) -> None:
     dtype = tvm_ffi.dtype(dtype_str)
     dtype_pickled = pickle.loads(pickle.dumps(dtype))
     assert dtype_pickled.type_code == dtype.type_code
@@ -74,7 +74,7 @@ def test_dtype_pickle(dtype_str):
 
 
 @pytest.mark.parametrize("dtype_str", ["float32", "bool"])
-def test_dtype_with_lanes(dtype_str):
+def test_dtype_with_lanes(dtype_str: str) -> None:
     dtype = tvm_ffi.dtype(dtype_str)
     dtype_with_lanes = dtype.with_lanes(4)
     assert dtype_with_lanes.type_code == dtype.type_code
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index 7b757ad..3b77ed7 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -16,11 +16,13 @@
 # under the License.
 
 
+from typing import NoReturn
+
 import pytest
 import tvm_ffi
 
 
-def test_parse_traceback():
+def test_parse_traceback() -> None:
     traceback = """
     File "test.py", line 1, in <module>
     File "test.py", line 3, in run_test
@@ -31,7 +33,7 @@ def test_parse_traceback():
     assert parsed[1] == ("test.py", 3, "run_test")
 
 
-def test_error_from_cxx():
+def test_error_from_cxx() -> None:
     test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error")
 
     try:
@@ -51,14 +53,14 @@ def test_error_from_cxx():
         tvm_ffi.convert(lambda x: x)()
 
 
-def test_error_from_nested_pyfunc():
+def test_error_from_nested_pyfunc() -> None:
     fapply = tvm_ffi.convert(lambda f, *args: f(*args))
     cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error")
     cxx_test_apply = tvm_ffi.get_global_func("testing.apply")
 
     record_object = []
 
-    def raise_error():
+    def raise_error() -> None:
         try:
             fapply(cxx_test_raise_error, "ValueError", "error XYZ")
         except ValueError as e:
@@ -87,10 +89,10 @@ def test_error_from_nested_pyfunc():
             pytest.xfail("May fail if debug symbols are missing")
 
 
-def test_error_traceback_update():
+def test_error_traceback_update() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
 
-    def raise_error():
+    def raise_error() -> NoReturn:
         raise ValueError("error XYZ")
 
     try:
@@ -99,7 +101,7 @@ def test_error_traceback_update():
         ffi_error = tvm_ffi.convert(e)
         assert ffi_error.traceback.find("raise_error") != -1
 
-    def raise_cxx_error():
+    def raise_cxx_error() -> None:
         cxx_test_raise_error = 
tvm_ffi.get_global_func("testing.test_raise_error")
         cxx_test_raise_error("ValueError", "error XYZ")
 
diff --git a/tests/python/test_examples.py b/tests/python/test_examples.py
index f8a9463..6f02144 100644
--- a/tests/python/test_examples.py
+++ b/tests/python/test_examples.py
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # testcases appearing in example docstrings
+from typing import Any
+
 import tvm_ffi
 
 
-def test_register_global_func():
+def test_register_global_func() -> None:
     # we can use decorator to register a function
     @tvm_ffi.register_global_func("example.echo")
-    def echo(x):
+    def echo(x: Any) -> Any:
         return x
 
     # After registering, we can get the function by its name
@@ -33,13 +35,13 @@ def test_register_global_func():
     assert f(1) == 2
 
 
-def test_array():
+def test_array() -> None:
     a = tvm_ffi.convert([1, 2, 3])
     assert isinstance(a, tvm_ffi.Array)
     assert len(a) == 3
 
 
-def test_map():
+def test_map() -> None:
     amap = tvm_ffi.convert({"a": 1, "b": 2})
     assert isinstance(amap, tvm_ffi.Map)
     assert len(amap) == 2
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 43d9e1f..d0afd5f 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -18,12 +18,13 @@
 import ctypes
 import gc
 import sys
+from typing import Any
 
 import numpy as np
 import tvm_ffi
 
 
-def test_echo():
+def test_echo() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
     assert isinstance(fecho, tvm_ffi.Function)
     # test each type
@@ -75,7 +76,7 @@ def test_echo():
     assert fadd1(1, 2) == 3
     assert fadd1.same_as(fadd)
 
-    def check_tensor():
+    def check_tensor() -> None:
         np_data = np.arange(10, dtype="int32")
         if not hasattr(np_data, "__dlpack__"):
             return
@@ -92,13 +93,13 @@ def test_echo():
     check_tensor()
 
 
-def test_return_raw_str_bytes():
+def test_return_raw_str_bytes() -> None:
     assert tvm_ffi.convert(lambda: "hello")() == "hello"
     assert tvm_ffi.convert(lambda: b"hello")() == b"hello"
     assert tvm_ffi.convert(lambda: bytearray(b"hello"))() == b"hello"
 
 
-def test_string_bytes_passing():
+def test_string_bytes_passing() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
     use_count = tvm_ffi.get_global_func("testing.object_use_count")
     # small string
@@ -119,7 +120,7 @@ def test_string_bytes_passing():
     fecho(y) == 1
 
 
-def test_nested_container_passing():
+def test_nested_container_passing() -> None:
     # test and make sure our ref counting is correct
     fecho = tvm_ffi.get_global_func("testing.echo")
     use_count = tvm_ffi.get_global_func("testing.object_use_count")
@@ -131,24 +132,24 @@ def test_nested_container_passing():
     assert use_count(y[1]) == 2
 
 
-def test_pyfunc_convert():
-    def add(a, b):
+def test_pyfunc_convert() -> None:
+    def add(a: int, b: int) -> int:
         return a + b
 
     fadd = tvm_ffi.convert(add)
     assert isinstance(fadd, tvm_ffi.Function)
     assert fadd(1, 2) == 3
 
-    def fapply(f, *args):
+    def fapply(f: Any, *args: Any) -> Any:
         return f(*args)
 
     fapply = tvm_ffi.convert(fapply)
     assert fapply(add, 1, 3.3) == 4.3
 
 
-def test_global_func():
+def test_global_func() -> None:
     @tvm_ffi.register_global_func("mytest.echo")
-    def echo(x):
+    def echo(x: Any) -> Any:
         return x
 
     f = tvm_ffi.get_global_func("mytest.echo")
@@ -162,10 +163,10 @@ def test_global_func():
     assert tvm_ffi.get_global_func("mytest.echo", allow_missing=True) is None
 
 
-def test_rvalue_ref():
+def test_rvalue_ref() -> None:
     use_count = tvm_ffi.get_global_func("testing.object_use_count")
 
-    def callback(x, expected_count):
+    def callback(x: Any, expected_count: int) -> Any:
         # The use count of TVM FFI objects is decremented as part of
         # `ObjectRef.__del__`, which runs when the Python object is
         # destructed.  However, Python object destruction is not
@@ -179,14 +180,14 @@ def test_rvalue_ref():
 
     f = tvm_ffi.convert(callback)
 
-    def check0():
+    def check0() -> None:
         x = tvm_ffi.convert([1, 2])
         assert use_count(x) == 1
         f(x, 2)
         f(x._move(), 1)
         assert x.__ctypes_handle__().value is None
 
-    def check1():
+    def check1() -> None:
         x = tvm_ffi.convert([1, 2])
         assert use_count(x) == 1
         y = f(x, 2)
@@ -198,9 +199,9 @@ def test_rvalue_ref():
     check1()
 
 
-def test_echo_with_opaque_object():
+def test_echo_with_opaque_object() -> None:
     class MyObject:
-        def __init__(self, value):
+        def __init__(self, value: Any) -> None:
             self.value = value
 
     fecho = tvm_ffi.get_global_func("testing.echo")
@@ -211,7 +212,7 @@ def test_echo_with_opaque_object():
     assert y is x
     assert sys.getrefcount(x) == 3
 
-    def py_callback(z):
+    def py_callback(z: Any) -> Any:
         """Python callback with opaque object."""
         assert z is x
         return z
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 25dcee4..5454284 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -28,7 +28,7 @@ import tvm_ffi.cpp
 from tvm_ffi.module import Module
 
 
-def test_load_inline_cpp():
+def test_load_inline_cpp() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
@@ -54,7 +54,7 @@ def test_load_inline_cpp():
     numpy.testing.assert_equal(x + 1, y)
 
 
-def test_load_inline_cpp_with_docstrings():
+def test_load_inline_cpp_with_docstrings() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
@@ -80,7 +80,7 @@ def test_load_inline_cpp_with_docstrings():
     numpy.testing.assert_equal(x + 1, y)
 
 
-def test_load_inline_cpp_multiple_sources():
+def test_load_inline_cpp_multiple_sources() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=[
@@ -122,7 +122,7 @@ def test_load_inline_cpp_multiple_sources():
     numpy.testing.assert_equal(x + 1, y)
 
 
-def test_load_inline_cpp_build_dir():
+def test_load_inline_cpp_build_dir() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
@@ -152,7 +152,7 @@ def test_load_inline_cpp_build_dir():
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available(), reason="Requires torch and 
CUDA"
 )
-def test_load_inline_cuda():
+def test_load_inline_cuda() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cuda_sources=r"""
@@ -196,7 +196,7 @@ def test_load_inline_cuda():
 
 
 @pytest.mark.skipif(torch is None, reason="Requires torch")
-def test_load_inline_with_env_tensor_allocator():
+def test_load_inline_with_env_tensor_allocator() -> None:
     if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
         pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
     mod: Module = tvm_ffi.cpp.load_inline(
@@ -240,7 +240,7 @@ def test_load_inline_with_env_tensor_allocator():
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available(), reason="Requires torch and 
CUDA"
 )
-def test_load_inline_both():
+def test_load_inline_both() -> None:
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 2da92e8..bcfb52e 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 import sys
+from typing import Any
 
 import pytest
 import tvm_ffi
 
 
-def test_make_object():
+def test_make_object() -> None:
     # with default values
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase")
     assert obj0.v_i64 == 10
@@ -28,14 +29,14 @@ def test_make_object():
     assert obj0.v_str == "hello"
 
 
-def test_method():
+def test_method() -> None:
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
     assert obj0.add_i64(1) == 13
     assert type(obj0).add_i64.__doc__ == "add_i64 method"
     assert type(obj0).v_i64.__doc__ == "i64 field"
 
 
-def test_setter():
+def test_setter() -> None:
     # test setter
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, 
v_str="hello")
     assert obj0.v_i64 == 10
@@ -51,7 +52,7 @@ def test_setter():
         obj0.v_i64 = "hello"
 
 
-def test_derived_object():
+def test_derived_object() -> None:
     with pytest.raises(TypeError):
         obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived")
 
@@ -72,11 +73,11 @@ def test_derived_object():
 
 
 class MyObject:
-    def __init__(self, value):
+    def __init__(self, value: Any) -> None:
         self.value = value
 
 
-def test_opaque_object():
+def test_opaque_object() -> None:
     obj0 = MyObject("hello")
     assert sys.getrefcount(obj0) == 2
     obj0_converted = tvm_ffi.convert(obj0)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index 34e5ccc..cfaf650 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -25,7 +25,7 @@ except ImportError:
     torch = None
 
 
-def gen_check_stream_mod():
+def gen_check_stream_mod() -> tvm_ffi.Module:
     return tvm_ffi.cpp.load_inline(
         name="check_stream",
         cpp_sources="""
@@ -38,7 +38,7 @@ def gen_check_stream_mod():
     )
 
 
-def test_raw_stream():
+def test_raw_stream() -> None:
     mod = gen_check_stream_mod()
     device = tvm_ffi.device("cuda:0")
     stream_1 = 123456789
@@ -55,7 +55,7 @@ def test_raw_stream():
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available(), reason="Requires torch and 
CUDA"
 )
-def test_torch_stream():
+def test_torch_stream() -> None:
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
@@ -77,7 +77,7 @@ def test_torch_stream():
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available(), reason="Requires torch and 
CUDA"
 )
-def test_torch_current_stream():
+def test_torch_current_stream() -> None:
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
@@ -102,7 +102,7 @@ def test_torch_current_stream():
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available(), reason="Requires torch and 
CUDA"
 )
-def test_torch_graph():
+def test_torch_graph() -> None:
     mod = gen_check_stream_mod()
     device_id = torch.cuda.current_device()
     device = tvm_ffi.device("cuda", device_id)
diff --git a/tests/python/test_string.py b/tests/python/test_string.py
index 36b8db8..5dd06f4 100644
--- a/tests/python/test_string.py
+++ b/tests/python/test_string.py
@@ -20,7 +20,7 @@ import pickle
 import tvm_ffi
 
 
-def test_string():
+def test_string() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
     s = tvm_ffi.core.String("hello")
     s2 = fecho(s)
@@ -35,7 +35,7 @@ def test_string():
     assert s4 == "hello"
 
 
-def test_bytes():
+def test_bytes() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
     b = tvm_ffi.core.Bytes(b"hello")
     assert isinstance(b, tvm_ffi.core.Bytes)
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 13cb76b..6d1da26 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -25,7 +25,7 @@ import numpy as np
 import tvm_ffi
 
 
-def test_tensor_attributes():
+def test_tensor_attributes() -> None:
     data = np.zeros((10, 8, 4, 2), dtype="int16")
     if not hasattr(data, "__dlpack__"):
         return
@@ -39,7 +39,7 @@ def test_tensor_attributes():
     np.testing.assert_equal(x2, data)
 
 
-def test_shape_object():
+def test_shape_object() -> None:
     shape = tvm_ffi.Shape((10, 8, 4, 2))
     assert isinstance(shape, tvm_ffi.Shape)
     assert shape == (10, 8, 4, 2)
@@ -56,7 +56,7 @@ def test_shape_object():
 
 
 @pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not 
enabled")
-def test_tensor_auto_dlpack():
+def test_tensor_auto_dlpack() -> None:
     x = torch.arange(128)
     fecho = tvm_ffi.get_global_func("testing.echo")
     y = fecho(x)
diff --git a/tests/scripts/benchmark_dlpack.py 
b/tests/scripts/benchmark_dlpack.py
index bef5284..2d9b296 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -31,24 +31,25 @@ Summary of some takeaways:
 """
 
 import time
+from typing import Any, Callable
 
 import numpy as np
 import torch
 import tvm_ffi
 
 
-def print_speed(name, speed):
+def print_speed(name: str, speed: float) -> None:
     print(f"{name:<60} {speed} sec/call")
 
 
-def print_error(name, error):
+def print_error(name: str, error: Any) -> None:
     print(f"{name:<60} {error}")
 
 
-def baseline_torch_add(repeat):
+def baseline_torch_add(repeat: int) -> None:
     """Run torch.add with one element."""
 
-    def run_bench(device):
+    def run_bench(device: str) -> None:
         x = torch.arange(1, device=device)
         y = torch.arange(1, device=device)
         z = torch.arange(1, device=device)
@@ -69,7 +70,7 @@ def baseline_torch_add(repeat):
     run_bench("cuda")
 
 
-def baseline_numpy_add(repeat):
+def baseline_numpy_add(repeat: int) -> None:
     """Run numpy.add with one element."""
     x = np.arange(1)
     y = np.arange(1)
@@ -84,7 +85,7 @@ def baseline_numpy_add(repeat):
     print_speed("numpy.add", speed)
 
 
-def baseline_cupy_add(repeat):
+def baseline_cupy_add(repeat: int) -> None:
     """Run cupy.add with one element."""
     try:
         import cupy  # noqa: PLC0415
@@ -104,7 +105,7 @@ def baseline_cupy_add(repeat):
     print_speed("cupy.add", speed)
 
 
-def tvm_ffi_nop(repeat):
+def tvm_ffi_nop(repeat: int) -> None:
     """Overhead of tvm FFI python call via calling a NOP.
 
     testing.nop is defined in c++ and do nothing.
@@ -121,7 +122,7 @@ def tvm_ffi_nop(repeat):
     print_speed("tvm_ffi.nop", (end - start) / repeat)
 
 
-def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
+def bench_ffi_nop_from_dlpack(name: str, x: Any, y: Any, z: Any, repeat: int) 
-> None:
     """Run dlpack conversion + tvm_ffi.nop.
 
     Measures overhead of running dlpack for each args then invoke
@@ -142,7 +143,7 @@ def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
     print_speed(name, (end - start) / repeat)
 
 
-def tvm_ffi_nop_from_torch_dlpack(repeat):
+def tvm_ffi_nop_from_torch_dlpack(repeat: int) -> None:
     """Run dlpack conversion + tvm_ffi.nop.
 
     Measures overhead of running dlpack for each args then invoke
@@ -153,7 +154,7 @@ def tvm_ffi_nop_from_torch_dlpack(repeat):
     bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(torch)", x, y, z, 
repeat)
 
 
-def tvm_ffi_nop_from_numpy_dlpack(repeat):
+def tvm_ffi_nop_from_numpy_dlpack(repeat: int) -> None:
     """Run dlpack conversion + tvm_ffi.nop.
 
     Measures overhead of running dlpack for each args then invoke
@@ -164,7 +165,7 @@ def tvm_ffi_nop_from_numpy_dlpack(repeat):
     bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(numpy)", x, y, z, 
repeat)
 
 
-def tvm_ffi_self_dlpack_nop(repeat):
+def tvm_ffi_self_dlpack_nop(repeat: int) -> None:
     """Run dlpack conversion + tvm_ffi.nop.
 
     Measures overhead of running dlpack for each args then invoke
@@ -175,7 +176,7 @@ def tvm_ffi_self_dlpack_nop(repeat):
     bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(tvm)", x, y, z, repeat)
 
 
-def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
+def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat: int) -> None:
     """Measures overhead of running dlpack for each args then invoke
     but uses the legacy torch.utils.dlpack.to_dlpack API.
 
@@ -202,7 +203,7 @@ def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
     print_speed("tvm_ffi.nop+from_dlpack(torch.utils)", speed)
 
 
-def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
+def bench_tvm_ffi_nop_autodlpack(name: str, x: Any, y: Any, z: Any, repeat: 
int) -> None:
     """Measures overhead of running dlpack via auto convert by directly
     take torch.Tensor as inputs.
     """
@@ -216,7 +217,9 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)
 
 
-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
+def tvm_ffi_nop_autodlpack_from_torch(
+    repeat: int, device: str = "cpu", stream: bool = False
+) -> None:
     """Measures overhead of running dlpack via auto convert by directly
     take torch.Tensor as inputs.
     """
@@ -233,7 +236,7 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", 
stream=False):
         
bench_tvm_ffi_nop_autodlpack(f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, 
z, repeat)
 
 
-def tvm_ffi_nop_autodlpack_from_numpy(repeat):
+def tvm_ffi_nop_autodlpack_from_numpy(repeat: int) -> None:
     """Measures overhead of running dlpack via auto convert by directly
     take numpy.ndarray as inputs.
     """
@@ -244,7 +247,7 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
     bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, 
repeat)
 
 
-def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
+def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: 
str) -> None:
     """Measures overhead of running dlpack via auto convert by directly
     take test wrapper as inputs. This effectively measure DLPack exchange in 
tvm ffi.
     """
@@ -259,7 +262,7 @@ def 
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
     )
 
 
-def bench_to_dlpack(x, name, repeat):
+def bench_to_dlpack(x: Any, name: str, repeat: int) -> None:
     x.__dlpack__()
     start = time.time()
     for i in range(repeat):
@@ -269,7 +272,9 @@ def bench_to_dlpack(x, name, repeat):
     print_speed(name, speed)
 
 
-def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)):
+def bench_to_dlpack_versioned(
+    x: Any, name: str, repeat: int, max_version: tuple[int, int] = (1, 1)
+) -> None:
     """Measures overhead of running dlpack with latest 1.1."""
     try:
         x.__dlpack__(max_version=max_version)
@@ -283,7 +288,7 @@ def bench_to_dlpack_versioned(x, name, repeat, 
max_version=(1, 1)):
         print_error(name, e)
 
 
-def bench_torch_utils_to_dlpack(repeat):
+def bench_torch_utils_to_dlpack(repeat: int) -> None:
     """Measures overhead of running torch.utils.dlpack.to_dlpack."""
     x = torch.arange(1)
     torch.utils.dlpack.to_dlpack(x)
@@ -295,11 +300,11 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)
 
 
-def torch_get_cuda_stream_native(device_id):
+def torch_get_cuda_stream_native(device_id: int) -> int:
     return torch.cuda.current_stream(device_id).cuda_stream
 
 
-def load_torch_get_current_cuda_stream():
+def load_torch_get_current_cuda_stream() -> Callable[[int], int]:
     """Create a faster get_current_cuda_stream for torch through cpp 
extension."""
     from torch.utils import cpp_extension  # noqa: PLC0415
 
@@ -325,7 +330,7 @@ def load_torch_get_current_cuda_stream():
     return result.get_current_cuda_stream
 
 
-def bench_torch_get_current_stream(repeat, name, func):
+def bench_torch_get_current_stream(repeat: int, name: str, func: 
Callable[[int], int]) -> None:
     """Measures overhead of running torch.cuda.current_stream."""
     x = torch.arange(1, device="cuda")  # noqa: F841
     func(0)
@@ -337,14 +342,14 @@ def bench_torch_get_current_stream(repeat, name, func):
     print_speed(f"torch.cuda.current_stream[{name}]", speed)
 
 
-def populate_object_table(num_classes):
+def populate_object_table(num_classes: int) -> None:
     nop = tvm_ffi.get_global_func("testing.nop")
     dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in 
range(num_classes)]
     for instance in dummy_instances:
         nop(instance)
 
 
-def main():  # noqa: PLR0915
+def main() -> None:  # noqa: PLR0915
     repeat = 10000
     # measures impact of object dispatch table size
     # takeaway so far is that there is no impact on the performance

Reply via email to