This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 0bc968d chore: Basic Ruff rule coverage (#20)
0bc968d is described below
commit 0bc968d1c6c76db80e69d2e860eafc8e0a3e9a69
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 17 15:38:11 2025 -0700
chore: Basic Ruff rule coverage (#20)
Enables a subset of Ruff rules including:
```
select = [
"UP", # pyupgrade, https://docs.astral.sh/ruff/rules/#pyupgrade-up
"PL", # pylint, https://docs.astral.sh/ruff/rules/#pylint-pl
"I", # isort, https://docs.astral.sh/ruff/rules/#isort-i
"RUF", # ruff, https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
"NPY", # numpy,
https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
"F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f
# "ANN", # flake8-annotations,
https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
"PTH", # flake8-use-pathlib,
https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth
"D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d
]
```
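For reference, with this configuration in pyproject.toml, the lint pass can be reproduced locally along these lines (a minimal sketch assuming Ruff is installed; the exact CI invocation is not shown in this commit):
```
# Report violations of the selected rule sets (UP, PL, I, RUF, NPY, F, PTH, D)
ruff check python tests
# Apply the autofixes permitted by `fixable = ["ALL"]`, then reformat
ruff check --fix python tests
ruff format python tests
```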
---
docs/conf.py | 7 +-
examples/inline_module/main.py | 3 +-
.../packaging/python/my_ffi_extension/__init__.py | 9 +-
.../packaging/python/my_ffi_extension/_ffi_api.py | 1 +
examples/packaging/python/my_ffi_extension/base.py | 10 +-
examples/packaging/run_example.py | 4 +
examples/quick_start/run_example.py | 4 +-
pyproject.toml | 90 ++++++++++-------
python/tvm_ffi/__init__.py | 36 +++----
python/tvm_ffi/_convert.py | 3 +-
python/tvm_ffi/_dtype.py | 10 +-
python/tvm_ffi/_optional_torch_c_dlpack.py | 7 +-
python/tvm_ffi/_tensor.py | 8 +-
python/tvm_ffi/access_path.py | 39 ++++---
python/tvm_ffi/config.py | 41 +++-----
python/tvm_ffi/container.py | 35 +++++--
python/tvm_ffi/cpp/__init__.py | 1 +
python/tvm_ffi/cpp/load_inline.py | 112 +++++++++------------
python/tvm_ffi/error.py | 20 ++--
python/tvm_ffi/libinfo.py | 83 +++++++--------
python/tvm_ffi/module.py | 43 +++++---
python/tvm_ffi/registry.py | 29 +++---
python/tvm_ffi/serialization.py | 10 +-
python/tvm_ffi/stream.py | 22 +++-
python/tvm_ffi/testing.py | 12 +--
python/tvm_ffi/utils/__init__.py | 1 +
python/tvm_ffi/utils/lockfile.py | 34 +++----
tests/lint/check_asf_header.py | 45 ++++-----
tests/lint/check_file_type.py | 20 ++--
tests/lint/git-clang-format.sh | 92 -----------------
tests/python/test_access_path.py | 27 ++---
tests/python/test_container.py | 3 +-
tests/python/test_device.py | 5 +-
tests/python/test_dtype.py | 1 -
tests/python/test_error.py | 1 -
tests/python/test_function.py | 3 +-
tests/python/test_object.py | 5 +-
tests/python/test_stream.py | 1 -
tests/python/test_tensor.py | 1 -
tests/scripts/benchmark_dlpack.py | 87 ++++++----------
tests/scripts/task_lint.sh | 46 ---------
41 files changed, 434 insertions(+), 577 deletions(-)
diff --git a/docs/conf.py b/docs/conf.py
index e621878..2830711 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Sphinx configuration for the tvm-ffi documentation site."""
+
# -*- coding: utf-8 -*-
import os
+from pathlib import Path
import tomli
@@ -25,9 +28,8 @@ build_exhale = os.environ.get("BUILD_CPP_DOCS", "0") == "1"
# -- General configuration ------------------------------------------------
-
# Load version from pyproject.toml
-with open("../pyproject.toml", "rb") as f:
+with Path("../pyproject.toml").open("rb") as f:
pyproject_data = tomli.load(f)
__version__ = pyproject_data["project"]["version"]
@@ -181,6 +183,7 @@ footer_note = (
def footer_html():
+ """Generate HTML for the documentation footer."""
# Create footer HTML with two-line layout
# Generate dropdown menu items
dropdown_items = ""
diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py
index 98b939e..8afa9b5 100644
--- a/examples/inline_module/main.py
+++ b/examples/inline_module/main.py
@@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Example: Build and run an inline C++/CUDA tvm-ffi module."""
import torch
-
import tvm_ffi.cpp
from tvm_ffi.module import Module
def main():
+ """Build, load, and run inline CPU/CUDA functions."""
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py
index 766a099..583945b 100644
--- a/examples/packaging/python/my_ffi_extension/__init__.py
+++ b/examples/packaging/python/my_ffi_extension/__init__.py
@@ -15,14 +15,14 @@
# specific language governing permissions and limitations.
# order matters here so we need to skip isort here
# isort: skip_file
+"""Public Python API for the example tvm-ffi extension package."""
from .base import _LIB
from . import _ffi_api
def add_one(x, y):
- """
- Adds one to the input tensor.
+ """Add one to the input tensor.
Parameters
----------
@@ -30,13 +30,13 @@ def add_one(x, y):
The input tensor.
y : Tensor
The output tensor.
+
"""
return _LIB.add_one(x, y)
def raise_error(msg):
- """
- Raises an error with the given message.
+ """Raise an error with the given message.
Parameters
----------
@@ -47,5 +47,6 @@ def raise_error(msg):
------
RuntimeError
The error raised by the function.
+
"""
return _ffi_api.raise_error(msg)
diff --git a/examples/packaging/python/my_ffi_extension/_ffi_api.py b/examples/packaging/python/my_ffi_extension/_ffi_api.py
index 5e03489..edc7677 100644
--- a/examples/packaging/python/my_ffi_extension/_ffi_api.py
+++ b/examples/packaging/python/my_ffi_extension/_ffi_api.py
@@ -17,6 +17,7 @@
import tvm_ffi
# make sure lib is loaded first
+from .base import _LIB # noqa: F401
# this is a short cut to register all the global functions
# prefixed by `my_ffi_extension.` to this module
diff --git a/examples/packaging/python/my_ffi_extension/base.py b/examples/packaging/python/my_ffi_extension/base.py
index fa17252..5b1546f 100644
--- a/examples/packaging/python/my_ffi_extension/base.py
+++ b/examples/packaging/python/my_ffi_extension/base.py
@@ -14,15 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations.
# Base logic to load library for extension package
-import os
+"""Utilities to locate and load the example extension shared library."""
+
import sys
+from pathlib import Path
import tvm_ffi
def _load_lib():
# first look at the directory of the current file
- file_dir = os.path.dirname(os.path.realpath(__file__))
+ file_dir = Path(__file__).resolve().parent
if sys.platform.startswith("win32"):
lib_dll_name = "my_ffi_extension.dll"
@@ -31,8 +33,8 @@ def _load_lib():
else:
lib_dll_name = "my_ffi_extension.so"
- lib_path = os.path.join(file_dir, lib_dll_name)
- return tvm_ffi.load_module(lib_path)
+ lib_path = file_dir / lib_dll_name
+ return tvm_ffi.load_module(str(lib_path))
_LIB = _load_lib()
diff --git a/examples/packaging/run_example.py b/examples/packaging/run_example.py
index 5304409..04650ec 100644
--- a/examples/packaging/run_example.py
+++ b/examples/packaging/run_example.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations.
# Base logic to load library for extension package
+"""Run functions from the example packaged tvm-ffi extension."""
+
import sys
import my_ffi_extension
@@ -21,6 +23,7 @@ import torch
def run_add_one():
+ """Invoke add_one from the extension and print the result."""
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
my_ffi_extension.add_one(x, y)
@@ -28,6 +31,7 @@ def run_add_one():
def run_raise_error():
+ """Invoke raise_error from the extension to demonstrate error handling."""
my_ffi_extension.raise_error("This is an error")
diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py
index 698bc2a..830c3c7 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Quick start script to run tvm-ffi examples from prebuilt libraries."""
+
import tvm_ffi
try:
@@ -93,7 +95,7 @@ def run_add_one_cuda():
def main():
- """Main function to run the example."""
+ """Run the quick start example."""
run_add_one_cpu()
run_add_one_c()
run_add_one_cuda()
diff --git a/pyproject.toml b/pyproject.toml
index e047529..7783c75 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,7 +67,7 @@ cmake.build-type = "Release"
cmake.args = [
"-DTVM_FFI_ATTACH_DEBUG_SYMBOLS=ON",
"-DTVM_FFI_BUILD_TESTS=OFF",
- "-DTVM_FFI_BUILD_PYTHON_MODULE=ON"
+ "-DTVM_FFI_BUILD_PYTHON_MODULE=ON",
]
# Logging
@@ -106,64 +106,76 @@ sdist.include = [
"/tests/**/*",
]
-sdist.exclude = ["**/.git", "**/.github", "**/__pycache__", "**/*.pyc", "build", "dist"]
+sdist.exclude = [
+ "**/.git",
+ "**/.github",
+ "**/__pycache__",
+ "**/*.pyc",
+ "build",
+ "dist",
+]
[tool.pytest.ini_options]
testpaths = ["tests"]
-[tool.black]
-line-length = 100
-skip-magic-trailing-comma = true
-
-exclude = '''
-/(
- \.venv
- | build
- | docs
- | dist
- | 3rdparty/*
-)/
-'''
-
-[tool.isort]
-profile = "black"
-src_paths = ["python", "tests"]
-extend_skip = ["3rdparty"]
-line_length = 100
-skip_gitignore = true
-
[tool.ruff]
include = ["python/**/*.py", "tests/**/*.py"]
+line-length = 100
+indent-width = 4
+target-version = "py39"
[tool.ruff.lint]
+select = [
+ "UP", # pyupgrade, https://docs.astral.sh/ruff/rules/#pyupgrade-up
+ "PL", # pylint, https://docs.astral.sh/ruff/rules/#pylint-pl
+ "I", # isort, https://docs.astral.sh/ruff/rules/#isort-i
+ "RUF", # ruff, https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+ "NPY", # numpy, https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
+ "F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f
+ # "ANN", # flake8-annotations,
https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
+ "PTH", # flake8-use-pathlib,
https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth
+ "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d
+]
+ignore = [
+ "PLR2004", # pylint: magic-value-comparison
+ "ANN401", # flake8-annotations: any-type
+ "D203", # pydocstyle: incorrect-blank-line-before-class
+ "D213", # pydocstyle: multi-line-summary-second-line
+]
+fixable = ["ALL"]
+unfixable = []
[tool.ruff.lint.per-file-ignores]
-"__init__.py" = ["F401"]
-"tests/*" = ["E741"]
+"__init__.py" = ["F401"] # pyflakes: unused-import
+"tests/*" = [
+ "E741", # pycodestyle: ambiguous-variable-name
+ "D100", # pydocstyle: undocumented-public-module
+ "D101", # pydocstyle: undocumented-public-class
+ "D103", # pydocstyle: undocumented-public-function
+ "D107", # pydocstyle: undocumented-public-init
+ "D205", # pydocstyle: missing-blank-line-after-summary
+]
[tool.ruff.lint.pylint]
max-args = 10
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = false
+docstring-code-line-length = "dynamic"
+
[tool.cibuildwheel]
build-verbosity = 1
# only build up to cp312, cp312
# will be abi3 and can be used in future versions
-build = [
- "cp39-*",
- "cp310-*",
- "cp311-*",
- "cp312-*",
-]
-skip = [
- "*musllinux*"
-]
+build = ["cp39-*", "cp310-*", "cp311-*", "cp312-*"]
+skip = ["*musllinux*"]
# we only need to test on cp312
-test-skip = [
- "cp39-*",
- "cp310-*",
- "cp311-*",
-]
+test-skip = ["cp39-*", "cp310-*", "cp311-*"]
# focus on testing abi3 wheel
build-frontend = "build[uv]"
test-command = "pytest {package}/tests/python -vvs"
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index b3b070f..2381363 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -47,31 +47,31 @@ from . import testing
from . import _optional_torch_c_dlpack
__all__ = [
- "dtype",
+ "Array",
+ "DLDeviceType",
"Device",
+ "Device",
+ "Function",
+ "Map",
+ "Module",
"Object",
- "register_object",
- "register_global_func",
- "get_global_func",
- "remove_global_func",
- "init_ffi_api",
"Object",
"ObjectConvertible",
- "Function",
+ "Shape",
+ "Tensor",
+ "access_path",
"convert",
- "register_error",
- "Device",
"device",
- "DLDeviceType",
+ "dtype",
"from_dlpack",
- "Tensor",
- "Shape",
- "Array",
- "Map",
- "testing",
- "access_path",
+ "get_global_func",
+ "init_ffi_api",
+ "load_module",
+ "register_error",
+ "register_global_func",
+ "register_object",
+ "remove_global_func",
"serialization",
- "Module",
"system_lib",
- "load_module",
+ "testing",
]
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index cf311b2..cdeccbc 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -22,7 +22,7 @@ from typing import Any
from . import container, core
-def convert(value: Any) -> Any:
+def convert(value: Any) -> Any: # noqa: PLR0911
"""Convert a python object to ffi values.
Parameters
@@ -40,6 +40,7 @@ def convert(value: Any) -> Any:
Function arguments to ffi function calls are
automatically converted. So this function is mainly
only used in internal or testing scenarios.
+
"""
if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)):
return value
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 1664d98..a76d111 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -18,6 +18,7 @@
# pylint: disable=invalid-name
from enum import IntEnum
+from typing import Any, ClassVar
from . import core
@@ -54,11 +55,12 @@ class dtype(str):
----
This class subclasses str so it can be directly passed
into other array api's dtype arguments.
+
"""
__slots__ = ["__tvm_ffi_dtype__"]
- _NUMPY_DTYPE_TO_STR = {}
+ _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
def __new__(cls, content):
content = str(content)
@@ -70,8 +72,7 @@ class dtype(str):
return f"dtype('{self}')"
def with_lanes(self, lanes):
- """
- Create a new dtype with the given number of lanes.
+ """Create a new dtype with the given number of lanes.
Parameters
----------
@@ -82,6 +83,7 @@ class dtype(str):
-------
dtype
The new dtype with the given number of lanes.
+
"""
cdtype = core._create_dtype_from_tuple(
core.DataType,
@@ -128,7 +130,7 @@ try:
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
if hasattr(np, "float_"):
- dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64"
+ dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
except ImportError:
pass
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index b96e9d0..b8b1f8f 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -378,8 +378,8 @@ int64_t TorchDLPackTensorAllocatorPtr() {
"""
try:
# optionally import torch
- import torch
- from torch.utils import cpp_extension
+ import torch # noqa: PLC0415
+ from torch.utils import cpp_extension # noqa: PLC0415
include_paths = libinfo.include_paths()
extra_cflags = ["-O3"]
@@ -408,8 +408,7 @@ int64_t TorchDLPackTensorAllocatorPtr() {
pass
except Exception as e:
warnings.warn(
- f"Failed to load torch c dlpack extension: {e},"
- "EnvTensorAllocator will not be enabled."
+ f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator
will not be enabled."
)
return None
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index a9212b4..bea20a9 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -28,10 +28,11 @@ from .core import Device, DLDeviceType, Tensor, from_dlpack
class Shape(tuple, core.PyNativeObject):
"""Shape tuple that represents `ffi::Shape` returned by a ffi call.
- Note
+ Note:
----
This class subclasses `tuple` so it can be used in most places where
tuple is used in python array apis.
+
"""
def __new__(cls, content):
@@ -51,7 +52,7 @@ class Shape(tuple, core.PyNativeObject):
def device(device_type, index=None):
- """Construct a TVM FFI device with given device type and index
+ """Construct a TVM FFI device with given device type and index.
Parameters
----------
@@ -74,8 +75,9 @@ def device(device_type, index=None):
assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0)
assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)
+
"""
return core._CLASS_DEVICE(device_type, index)
-__all__ = ["from_dlpack", "Tensor", "device", "Device", "DLDeviceType"]
+__all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"]
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index 91a426b..e8aec10 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -18,13 +18,15 @@
"""Access path classes."""
from enum import IntEnum
-from typing import Any, List
+from typing import Any
from . import core
from .registry import register_object
class AccessKind(IntEnum):
+ """Kinds of access steps in an access path."""
+
ATTR = 0
ARRAY_ITEM = 1
MAP_ITEM = 2
@@ -35,14 +37,15 @@ class AccessKind(IntEnum):
@register_object("ffi.reflection.AccessStep")
class AccessStep(core.Object):
- """Access step container"""
+ """Access step container."""
@register_object("ffi.reflection.AccessPath")
class AccessPath(core.Object):
- """Access path container"""
+ """Access path container."""
def __init__(self) -> None:
+ """Disallow direct construction; use `AccessPath.root()` instead."""
super().__init__()
raise ValueError(
"AccessPath can't be initialized directly. "
@@ -51,21 +54,23 @@ class AccessPath(core.Object):
@staticmethod
def root() -> "AccessPath":
- """Create a root access path"""
+ """Create a root access path."""
return AccessPath._root()
def __eq__(self, other: Any) -> bool:
+ """Return whether two access paths are equal."""
if not isinstance(other, AccessPath):
return False
return self._path_equal(other)
def __ne__(self, other: Any) -> bool:
+ """Return whether two access paths are not equal."""
if not isinstance(other, AccessPath):
return True
return not self._path_equal(other)
def is_prefix_of(self, other: "AccessPath") -> bool:
- """Check if this access path is a prefix of another access path
+ """Check if this access path is a prefix of another access path.
Parameters
----------
@@ -76,11 +81,12 @@ class AccessPath(core.Object):
-------
bool
True if this access path is a prefix of the other access path,
False otherwise
+
"""
return self._is_prefix_of(other)
def attr(self, attr_key: str) -> "AccessPath":
- """Create an access path to the attribute of the current object
+ """Create an access path to the attribute of the current object.
Parameters
----------
@@ -91,11 +97,12 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._attr(attr_key)
def attr_missing(self, attr_key: str) -> "AccessPath":
- """Create an access path that indicate an attribute is missing
+ """Create an access path that indicate an attribute is missing.
Parameters
----------
@@ -106,11 +113,12 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._attr_missing(attr_key)
def array_item(self, index: int) -> "AccessPath":
- """Create an access path to the item of the current array
+ """Create an access path to the item of the current array.
Parameters
----------
@@ -121,11 +129,12 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._array_item(index)
def array_item_missing(self, index: int) -> "AccessPath":
- """Create an access path that indicate an array item is missing
+ """Create an access path that indicate an array item is missing.
Parameters
----------
@@ -136,11 +145,12 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._array_item_missing(index)
def map_item(self, key: Any) -> "AccessPath":
- """Create an access path to the item of the current map
+ """Create an access path to the item of the current map.
Parameters
----------
@@ -151,11 +161,12 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._map_item(key)
def map_item_missing(self, key: Any) -> "AccessPath":
- """Create an access path that indicate a map item is missing
+ """Create an access path that indicate a map item is missing.
Parameters
----------
@@ -166,16 +177,18 @@ class AccessPath(core.Object):
-------
AccessPath
The extended access path
+
"""
return self._map_item_missing(key)
- def to_steps(self) -> List["AccessStep"]:
- """Convert the access path to a list of access steps
+ def to_steps(self) -> list["AccessStep"]:
+ """Convert the access path to a list of access steps.
Returns
-------
List[AccessStep]
The list of access steps
+
"""
return self._to_steps()
diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py
index 7e03680..bc31de4 100644
--- a/python/tvm_ffi/config.py
+++ b/python/tvm_ffi/config.py
@@ -14,51 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Config utilities for finding paths to lib and headers"""
+"""Config utilities for finding paths to lib and headers."""
import argparse
-import os
import sys
+from pathlib import Path
from . import libinfo
def find_windows_implib():
- libdir = os.path.dirname(libinfo.find_libtvm_ffi())
- implib = os.path.join(libdir, "tvm_ffi.lib")
- if not os.path.isfile(implib):
+ """Find and return the Windows import library path for tvm_ffi.lib."""
+ libdir = Path(libinfo.find_libtvm_ffi()).parent
+ implib = libdir / "tvm_ffi.lib"
+ if not implib.is_file():
raise RuntimeError(f"Cannot find imp lib {implib}")
- return implib
+ return str(implib)
-def __main__():
- """Main function"""
+def __main__(): # noqa: PLR0912
+ """Parse CLI args and print build and include configuration paths."""
parser = argparse.ArgumentParser(
description="Get various configuration information needed to compile
with tvm-ffi"
)
- parser.add_argument(
- "--includedir", action="store_true", help="Print include directory"
- )
+ parser.add_argument("--includedir", action="store_true", help="Print
include directory")
parser.add_argument(
"--dlpack-includedir",
action="store_true",
help="Print dlpack include directory",
)
- parser.add_argument(
- "--cmakedir", action="store_true", help="Print library directory"
- )
- parser.add_argument(
- "--sourcedir", action="store_true", help="Print source directory"
- )
- parser.add_argument(
- "--libfiles", action="store_true", help="Fully qualified library
filenames"
- )
+ parser.add_argument("--cmakedir", action="store_true", help="Print library
directory")
+ parser.add_argument("--sourcedir", action="store_true", help="Print source
directory")
+ parser.add_argument("--libfiles", action="store_true", help="Fully
qualified library filenames")
parser.add_argument("--libdir", action="store_true", help="Print library
directory")
parser.add_argument("--libs", action="store_true", help="Libraries to be
linked")
- parser.add_argument(
- "--cython-lib-path", action="store_true", help="Print cython path"
- )
+ parser.add_argument("--cython-lib-path", action="store_true", help="Print
cython path")
parser.add_argument("--cxxflags", action="store_true", help="Print cxx
flags")
parser.add_argument("--cflags", action="store_true", help="Print c flags")
parser.add_argument("--ldflags", action="store_true", help="Print ld
flags")
@@ -77,7 +68,7 @@ def __main__():
if args.cmakedir:
print(libinfo.find_cmake_path())
if args.libdir:
- print(os.path.dirname(libinfo.find_libtvm_ffi()))
+ print(Path(libinfo.find_libtvm_ffi()).parent)
if args.libfiles:
if sys.platform.startswith("win32"):
print(find_windows_implib())
@@ -102,7 +93,7 @@ def __main__():
print("-ltvm_ffi")
if args.ldflags:
if not sys.platform.startswith("win32"):
- print(f"-L{os.path.dirname(libinfo.find_libtvm_ffi())}")
+ print(f"-L{Path(libinfo.find_libtvm_ffi()).parent}")
if __name__ == "__main__":
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 8368cd4..9bb9f97 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -17,7 +17,8 @@
"""Container classes."""
import collections.abc
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any
from . import _ffi_api, core
from .registry import register_object
@@ -26,7 +27,7 @@ __all__ = ["Array", "Map"]
def getitem_helper(obj, elem_getter, length, idx):
- """Helper function to implement a pythonic getitem function.
+ """Implement a pythonic __getitem__ helper.
Parameters
----------
@@ -46,6 +47,7 @@ def getitem_helper(obj, elem_getter, length, idx):
-------
result : object
The result of getitem
+
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
@@ -88,18 +90,23 @@ class Array(core.Object, collections.abc.Sequence):
a = tvm_ffi.convert([1, 2, 3])
assert isinstance(a, tvm_ffi.Array)
assert len(a) == 3
+
"""
def __init__(self, input_list: Sequence[Any]):
+ """Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
def __getitem__(self, idx):
+ """Return one element or a Python list for a slice."""
return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)
def __len__(self):
+ """Return the number of elements in the array."""
return _ffi_api.ArraySize(self)
def __repr__(self):
+ """Return a string representation of the array."""
# exception safety handling for chandle=None
if self.__chandle__() == 0:
return type(self).__name__ + "(chandle=None)"
@@ -107,7 +114,7 @@ class Array(core.Object, collections.abc.Sequence):
class KeysView(collections.abc.KeysView):
- """Helper class to return keys view"""
+ """Helper class to return keys view."""
def __init__(self, backend_map):
self._backend_map = backend_map
@@ -130,7 +137,7 @@ class KeysView(collections.abc.KeysView):
class ValuesView(collections.abc.ValuesView):
- """Helper class to return values view"""
+ """Helper class to return values view."""
def __init__(self, backend_map):
self._backend_map = backend_map
@@ -150,7 +157,7 @@ class ValuesView(collections.abc.ValuesView):
class ItemsView(collections.abc.ItemsView):
- """Helper class to return items view"""
+ """Helper class to return items view."""
def __init__(self, backend_map):
self.backend_map = backend_map
@@ -196,9 +203,11 @@ class Map(core.Object, collections.abc.Mapping):
assert len(amap) == 2
assert amap["a"] == 1
assert amap["b"] == 2
+
"""
def __init__(self, input_dict: Mapping[Any, Any]):
+ """Construct a Map from a Python mapping."""
list_kvs = []
for k, v in input_dict.items():
list_kvs.append(k)
@@ -206,25 +215,31 @@ class Map(core.Object, collections.abc.Mapping):
self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)
def __getitem__(self, k):
+ """Return the value for key `k` or raise KeyError."""
return _ffi_api.MapGetItem(self, k)
def __contains__(self, k):
+ """Return True if the map contains key `k`."""
return _ffi_api.MapCount(self, k) != 0
def keys(self):
+ """Return a dynamic view of the map's keys."""
return KeysView(self)
def values(self):
+ """Return a dynamic view of the map's values."""
return ValuesView(self)
def items(self):
- """Get the items from the map"""
+ """Get the items from the map."""
return ItemsView(self)
def __len__(self):
+ """Return the number of items in the map."""
return _ffi_api.MapSize(self)
def __iter__(self):
+ """Iterate over the map's keys."""
return iter(self.keys())
def get(self, key, default=None):
@@ -242,15 +257,13 @@ class Map(core.Object, collections.abc.Mapping):
-------
value: object
The result value.
+
"""
return self[key] if key in self else default
def __repr__(self):
+ """Return a string representation of the map."""
# exception safety handling for chandle=None
if self.__chandle__() == 0:
return type(self).__name__ + "(chandle=None)"
- return (
- "{"
- + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in
self.items()])
- + "}"
- )
+ return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in
self.items()]) + "}"
diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py
index 632698f..ede2b54 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/python/tvm_ffi/cpp/__init__.py
@@ -14,5 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""C++ integration helpers for building and loading inline modules."""
from .load_inline import load_inline
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index 6ce3d11..264a7bb 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -14,15 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja."""
import functools
-import glob
import hashlib
import os
import shutil
import subprocess
import sys
-from typing import Mapping, Optional, Sequence
+from collections.abc import Mapping, Sequence
+from pathlib import Path
+from typing import Optional
from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi
from tvm_ffi.module import Module, load_module
@@ -64,12 +66,13 @@ def _hash_sources(
def _maybe_write(path: str, content: str) -> None:
"""Write content to path if it does not already exist with the same
content."""
- if os.path.exists(path):
- with open(path, "r") as f:
+ p = Path(path)
+ if p.exists():
+ with p.open() as f:
existing_content = f.read()
if existing_content == content:
return
- with open(path, "w") as f:
+ with p.open("w") as f:
f.write(content)
@@ -82,23 +85,21 @@ def _find_cuda_home() -> Optional[str]:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
- cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
+ cuda_home = str(Path(nvcc_path).parent.parent)
else:
# Guess #3
if IS_WINDOWS:
- cuda_homes = glob.glob(
- "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
- )
+ cuda_root = Path("C:/Program Files/NVIDIA GPU Computing
Toolkit/CUDA")
+ cuda_homes = list(cuda_root.glob("v*.*"))
if len(cuda_homes) == 0:
cuda_home = ""
else:
- cuda_home = cuda_homes[0]
+ cuda_home = str(cuda_homes[0])
else:
cuda_home = "/usr/local/cuda"
- if not os.path.exists(cuda_home):
+ if not Path(cuda_home).exists():
raise RuntimeError(
- "Could not find CUDA installation. "
- "Please set CUDA_HOME environment variable."
+ "Could not find CUDA installation. Please set CUDA_HOME
environment variable."
)
return cuda_home
@@ -115,7 +116,6 @@ def _get_cuda_target() -> str:
flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}")
return " ".join(flags)
else:
- #
try:
status = subprocess.run(
args=["nvidia-smi", "--query-gpu=compute_cap",
"--format=csv,noheader"],
@@ -134,14 +134,14 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
"""Locates the Developer Command Prompt and runs a command within its
environment."""
try:
# Path to vswhere.exe
- vswhere_path = os.path.join(
- os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"),
- "Microsoft Visual Studio",
- "Installer",
- "vswhere.exe",
+ vswhere_path = str(
+ Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files
(x86)"))
+ / "Microsoft Visual Studio"
+ / "Installer"
+ / "vswhere.exe"
)
- if not os.path.exists(vswhere_path):
+ if not Path(vswhere_path).exists():
raise FileNotFoundError("vswhere.exe not found.")
# Find the Visual Studio installation path
@@ -164,11 +164,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
raise FileNotFoundError("No Visual Studio installation found.")
# Construct the path to the VsDevCmd.bat file
- vsdevcmd_path = os.path.join(
- vs_install_path, "Common7", "Tools", "VsDevCmd.bat"
- )
+ vsdevcmd_path = str(Path(vs_install_path) / "Common7" / "Tools" / "VsDevCmd.bat")
- if not os.path.exists(vsdevcmd_path):
+ if not Path(vsdevcmd_path).exists():
raise FileNotFoundError(f"VsDevCmd.bat not found at:
{vsdevcmd_path}")
# Use cmd.exe to run the batch file and then your command.
@@ -180,7 +178,7 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
# Execute the command in a new shell
return subprocess.run(
- cmd_command, cwd=cwd, capture_output=capture_output, shell=True
+ cmd_command, check=False, cwd=cwd, capture_output=capture_output, shell=True
)
except (FileNotFoundError, subprocess.CalledProcessError) as e:
@@ -191,7 +189,7 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
) from e
-def _generate_ninja_build(
+def _generate_ninja_build( # noqa: PLR0915
name: str,
build_dir: str,
with_cuda: bool,
@@ -204,8 +202,8 @@ def _generate_ninja_build(
default_include_paths = [find_include_path(), find_dlpack_include_path()]
tvm_ffi_lib = find_libtvm_ffi()
- tvm_ffi_lib_path = os.path.dirname(tvm_ffi_lib)
- tvm_ffi_lib_name = os.path.splitext(os.path.basename(tvm_ffi_lib))[0]
+ tvm_ffi_lib_path = str(Path(tvm_ffi_lib).parent)
+ tvm_ffi_lib_name = Path(tvm_ffi_lib).stem
if IS_WINDOWS:
default_cflags = [
"/std:c++17",
@@ -231,13 +229,13 @@ def _generate_ninja_build(
else:
default_cflags = ["-std=c++17", "-fPIC", "-O2"]
default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"]
- default_ldflags = ["-shared", "-L{}".format(tvm_ffi_lib_path),
"-ltvm_ffi"]
+ default_ldflags = ["-shared", f"-L{tvm_ffi_lib_path}", "-ltvm_ffi"]
if with_cuda:
# determine the compute capability of the current GPU
default_cuda_cflags += [_get_cuda_target()]
default_ldflags += [
- "-L{}".format(os.path.join(_find_cuda_home(), "lib64")),
+ "-L{}".format(str(Path(_find_cuda_home()) / "lib64")),
"-lcudart",
]
@@ -245,7 +243,7 @@ def _generate_ninja_build(
cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags]
ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
include_paths = default_include_paths + [
- os.path.abspath(path) for path in extra_include_paths
+ str(Path(path).resolve()) for path in extra_include_paths
]
# append include paths
@@ -256,12 +254,10 @@ def _generate_ninja_build(
# flags
ninja = []
ninja.append("ninja_required_version = 1.3")
- ninja.append(
- "cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))
- )
+ ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS
else "c++")))
ninja.append("cflags = {}".format(" ".join(cflags)))
if with_cuda:
- ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin",
"nvcc")))
+ ninja.append("nvcc = {}".format(str(Path(_find_cuda_home()) / "bin" /
"nvcc")))
ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags)))
ninja.append("ldflags = {}".format(" ".join(ldflags)))
@@ -296,24 +292,22 @@ def _generate_ninja_build(
# build targets
ninja.append(
"build main.o: compile {}".format(
- os.path.abspath(os.path.join(build_dir, "main.cpp")).replace(":",
"$:")
+ str((Path(build_dir) / "main.cpp").resolve()).replace(":", "$:")
)
)
if with_cuda:
ninja.append(
"build cuda.o: compile_cuda {}".format(
- os.path.abspath(os.path.join(build_dir, "cuda.cu")).replace(":", "$:")
+ str((Path(build_dir) / "cuda.cu").resolve()).replace(":", "$:")
)
)
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- ninja.append(
- "build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda
else "")
- )
+ ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if
with_cuda else ""))
ninja.append("")
# default target
- ninja.append("default {}{}".format(name, ext))
+ ninja.append(f"default {name}{ext}")
ninja.append("")
return "\n".join(ninja)
@@ -325,18 +319,16 @@ def _build_ninja(build_dir: str) -> None:
if num_workers is not None:
command += ["-j", num_workers]
if IS_WINDOWS:
- status = _run_command_in_dev_prompt(
- args=command, cwd=build_dir, capture_output=True
- )
+ status = _run_command_in_dev_prompt(args=command, cwd=build_dir, capture_output=True)
else:
- status = subprocess.run(args=command, cwd=build_dir, capture_output=True)
+ status = subprocess.run(check=False, args=command, cwd=build_dir, capture_output=True)
if status.returncode != 0:
- msg = ["ninja exited with status {}".format(status.returncode)]
+ msg = [f"ninja exited with status {status.returncode}"]
encoding = "oem" if IS_WINDOWS else "utf-8"
if status.stdout:
- msg.append("stdout:\n{}".format(status.stdout.decode(encoding)))
+ msg.append(f"stdout:\n{status.stdout.decode(encoding)}")
if status.stderr:
- msg.append("stderr:\n{}".format(status.stderr.decode(encoding)))
+ msg.append(f"stderr:\n{status.stderr.decode(encoding)}")
raise RuntimeError("\n".join(msg))
@@ -401,7 +393,6 @@ def load_inline(
Parameters
----------
-
name: str
The name of the tvm ffi module.
cpp_sources: Sequence[str] | str, optional
@@ -483,6 +474,7 @@ def load_inline(
y = torch.empty_like(x)
mod.add_one_cpu(x, y)
torch.testing.assert_close(x + 1, y)
+
"""
if cpp_sources is None:
cpp_sources = []
@@ -518,7 +510,7 @@ def load_inline(
# determine the cache dir for the built module
if build_directory is None:
build_directory = os.environ.get(
- "TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi")
+ "TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser())
)
source_hash: str = _hash_sources(
cpp_source,
@@ -529,12 +521,10 @@ def load_inline(
extra_ldflags,
extra_include_paths,
)
- build_dir: str = os.path.join(
- build_directory, "{}_{}".format(name, source_hash)
- )
+ build_dir: str = str(Path(build_directory) / f"{name}_{source_hash}")
else:
- build_dir = os.path.abspath(build_directory)
- os.makedirs(build_dir, exist_ok=True)
+ build_dir = str(Path(build_directory).resolve())
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
# generate build.ninja
ninja_source = _generate_ninja_build(
@@ -547,18 +537,16 @@ def load_inline(
extra_include_paths=extra_include_paths,
)
- with FileLock(os.path.join(build_dir, "lock")):
+ with FileLock(str(Path(build_dir) / "lock")):
# write source files and build.ninja if they do not already exist
- _maybe_write(os.path.join(build_dir, "main.cpp"), cpp_source)
+ _maybe_write(str(Path(build_dir) / "main.cpp"), cpp_source)
if with_cuda:
- _maybe_write(os.path.join(build_dir, "cuda.cu"), cuda_source)
- _maybe_write(os.path.join(build_dir, "build.ninja"), ninja_source)
+ _maybe_write(str(Path(build_dir) / "cuda.cu"), cuda_source)
+ _maybe_write(str(Path(build_dir) / "build.ninja"), ninja_source)
# build the module
_build_ninja(build_dir)
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- return load_module(
- os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext)))
- )
+ return load_module(str((Path(build_dir) / f"{name}{ext}").resolve()))
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index cec6956..28788ef 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -26,7 +26,7 @@ from . import core
def _parse_traceback(traceback):
- """Parse the traceback string into a list of (filename, lineno, func)
+ """Parse the traceback string into a list of (filename, lineno, func).
Parameters
----------
@@ -37,6 +37,7 @@ def _parse_traceback(traceback):
-------
result : List[Tuple[str, int, str]]
The list of (filename, lineno, func)
+
"""
pattern = r'File "(.+?)", line (\d+), in (.+)'
result = []
@@ -54,11 +55,10 @@ def _parse_traceback(traceback):
class TracebackManager:
- """
- Helper to manage traceback generation
- """
+ """Helper to manage traceback generation."""
def __init__(self):
+ """Initialize the traceback manager and its cache."""
self._code_cache = {}
def _get_cached_code_object(self, filename, lineno, func):
@@ -84,7 +84,7 @@ class TracebackManager:
return code_object
def _create_frame(self, filename, lineno, func):
- """Create a frame object from the filename, lineno, and func"""
+ """Create a frame object from the filename, lineno, and func."""
code_object = self._get_cached_code_object(filename, lineno, func)
# call into get frame, but changes the context so the code
# points to the correct frame
@@ -93,7 +93,7 @@ class TracebackManager:
return eval(code_object, context, context)
def append_traceback(self, tb, filename, lineno, func):
- """Append a traceback to the given traceback
+ """Append a traceback to the given traceback.
Parameters
----------
@@ -110,6 +110,7 @@ class TracebackManager:
-------
new_tb : types.TracebackType
The new traceback with the appended frame.
+
"""
frame = self._create_frame(filename, lineno, func)
return types.TracebackType(tb, frame, frame.f_lasti, lineno)
@@ -119,7 +120,7 @@ _TRACEBACK_MANAGER = TracebackManager()
def _with_append_traceback(py_error, traceback):
- """Append the traceback to the py_error and return it"""
+ """Append the traceback to the py_error and return it."""
tb = py_error.__traceback__
for filename, lineno, func in reversed(_parse_traceback(traceback)):
tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
@@ -127,7 +128,7 @@ def _with_append_traceback(py_error, traceback):
def _traceback_to_str(tb):
- """Convert the traceback to a string"""
+ """Convert the traceback to a string."""
lines = []
while tb is not None:
frame = tb.tb_frame
@@ -169,13 +170,14 @@ def register_error(name_or_cls=None, cls=None):
err_inst = tvm.error.create_ffi_error("MyError: xyz")
assert isinstance(err_inst, MyError)
+
"""
if callable(name_or_cls):
cls = name_or_cls
name_or_cls = cls.__name__
def register(mycls):
- """internal register function"""
+ """Register the error class name with the FFI core."""
err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__
core.ERROR_NAME_TO_TYPE[err_name] = mycls
core.ERROR_TYPE_TO_NAME[mycls] = err_name
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index 8325c35..b707f2b 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Utilities to locate tvm_ffi libraries, headers, and helper include paths."""
-import glob
import os
import sys
+from pathlib import Path
def split_env_var(env_var, split):
- """Splits environment variable string.
+ """Split an environment variable string.
Parameters
----------
@@ -35,6 +36,7 @@ def split_env_var(env_var, split):
-------
splits : list(string)
If env_var exists, split env_var. Otherwise, empty list.
+
"""
if os.environ.get(env_var, None):
return [p.strip() for p in os.environ[env_var].split(split)]
@@ -42,12 +44,12 @@ def split_env_var(env_var, split):
def get_dll_directories():
- """Get the possible dll directories"""
- ffi_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- dll_path = [os.path.join(ffi_dir, "lib")]
- dll_path += [os.path.join(ffi_dir, "..", "..", "build", "lib")]
+ """Get the possible dll directories."""
+ ffi_dir = Path(__file__).expanduser().resolve().parent
+ dll_path = [ffi_dir / "lib"]
+ dll_path += [ffi_dir / ".." / ".." / "build" / "lib"]
# in source build from parent if needed
- dll_path += [os.path.join(ffi_dir, "..", "..", "..", "build", "lib")]
+ dll_path += [ffi_dir / ".." / ".." / ".." / "build" / "lib"]
if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":"))
@@ -57,7 +59,7 @@ def get_dll_directories():
dll_path.extend(split_env_var("PATH", ":"))
elif sys.platform.startswith("win32"):
dll_path.extend(split_env_var("PATH", ";"))
- return [os.path.abspath(x) for x in dll_path if os.path.isdir(x)]
+ return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()]
def find_libtvm_ffi():
@@ -71,13 +73,11 @@ def find_libtvm_ffi():
lib_dll_names = ["libtvm_ffi.so"]
name = lib_dll_names
- lib_dll_path = [os.path.join(p, name) for name in lib_dll_names for p in dll_path]
- lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
+ lib_dll_path = [str(Path(p) / name) for name in lib_dll_names for p in dll_path]
+ lib_found = [p for p in lib_dll_path if Path(p).exists() and Path(p).is_file()]
if not lib_found:
- raise RuntimeError(
- f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}"
- )
+ raise RuntimeError(f"Cannot find library: {name}\nList of
candidates:\n{lib_dll_path}")
return lib_found[0]
@@ -85,11 +85,11 @@ def find_libtvm_ffi():
def find_source_path():
"""Find packaged source home path."""
candidates = [
- os.path.join(os.path.dirname(os.path.realpath(__file__))),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", ".."),
+ str(Path(__file__).resolve().parent),
+ str(Path(__file__).resolve().parent / ".." / ".."),
]
for candidate in candidates:
- if os.path.isdir(os.path.join(candidate, "cmake")):
+ if Path(candidate, "cmake").is_dir():
return candidate
raise RuntimeError("Cannot find home path.")
@@ -97,11 +97,11 @@ def find_source_path():
def find_cmake_path():
"""Find the preferred cmake path."""
candidates = [
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "cmake"),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",
"cmake"),
+ str(Path(__file__).resolve().parent / "cmake"),
+ str(Path(__file__).resolve().parent / ".." / ".." / "cmake"),
]
for candidate in candidates:
- if os.path.isdir(candidate):
+ if Path(candidate).is_dir():
return candidate
raise RuntimeError("Cannot find cmake path.")
@@ -109,13 +109,11 @@ def find_cmake_path():
def find_include_path():
"""Find header files for C compilation."""
candidates = [
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"),
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"
- ),
+ str(Path(__file__).resolve().parent / "include"),
+ str(Path(__file__).resolve().parent / ".." / ".." / "include"),
]
for candidate in candidates:
- if os.path.isdir(candidate):
+ if Path(candidate).is_dir():
return candidate
raise RuntimeError("Cannot find include path.")
@@ -123,33 +121,26 @@ def find_include_path():
def find_python_helper_include_path():
"""Find header files for C compilation."""
candidates = [
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "cython"),
+ str(Path(__file__).resolve().parent / "include"),
+ str(Path(__file__).resolve().parent / "cython"),
]
for candidate in candidates:
- if os.path.isfile(os.path.join(candidate, "tvm_ffi_python_helpers.h")):
+ if Path(candidate, "tvm_ffi_python_helpers.h").is_file():
return candidate
raise RuntimeError("Cannot find python helper include path.")
def find_dlpack_include_path():
"""Find dlpack header files for C compilation."""
- install_include_path = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "include"
- )
- if os.path.isdir(os.path.join(install_include_path, "dlpack")):
- return install_include_path
-
- source_include_path = os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "3rdparty",
- "dlpack",
- "include",
+ install_include_path = Path(__file__).resolve().parent / "include"
+ if (install_include_path / "dlpack").is_dir():
+ return str(install_include_path)
+
+ source_include_path = (
+ Path(__file__).resolve().parent / ".." / ".." / "3rdparty" / "dlpack" / "include"
)
- if os.path.isdir(source_include_path):
- return source_include_path
+ if source_include_path.is_dir():
+ return str(source_include_path)
raise RuntimeError("Cannot find include path.")
@@ -157,13 +148,13 @@ def find_dlpack_include_path():
def find_cython_lib():
"""Find the path to tvm cython."""
path_candidates = [
- os.path.dirname(os.path.realpath(__file__)),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",
"build"),
+ Path(__file__).resolve().parent,
+ Path(__file__).resolve().parent / ".." / ".." / "build",
]
suffixes = "pyd" if sys.platform.startswith("win32") else "so"
for candidate in path_candidates:
- for path in glob.glob(os.path.join(candidate, f"core*.{suffixes}")):
- return os.path.abspath(path)
+ for path in Path(candidate).glob(f"core*.{suffixes}"):
+ return str(Path(path).resolve())
raise RuntimeError("Cannot find tvm cython path.")
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index fbfb35d..335e262 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -22,7 +22,7 @@ from enum import IntEnum
from . import _ffi_api, core
from .registry import register_object
-__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"]
+__all__ = ["Module", "ModulePropertyMask", "load_module", "system_lib"]
class ModulePropertyMask(IntEnum):
@@ -37,7 +37,7 @@ class ModulePropertyMask(IntEnum):
class Module(core.Object):
"""Module container for dynamically loaded Module.
- Example
+ Example:
-------
.. code-block:: python
@@ -48,9 +48,10 @@ class Module(core.Object):
# you can use mod.func_name to call the exported function
mod.func_name(*args)
- See Also
+ See Also:
--------
:py:func:`tvm_ffi.load_module`
+
"""
# constant for entry function name
@@ -63,17 +64,22 @@ class Module(core.Object):
@property
def imports(self):
- """Get imported modules
+ """Get imported modules.
Returns
- ----------
+ -------
modules : list of Module
The module
+
"""
return self.imports_
def implements_function(self, name, query_imports=False):
- """Returns True if the module has a definition for the global function
with name. Note
+ """Return True if the module defines a global function.
+
+ Note
+ ----
+ that has_function(name) does not imply get_function(name) is non-null since the module
that has_function(name) does not imply get_function(name) is non-null since the module
may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function
without further compilation. However, get_function(name) non null should always imply
@@ -91,6 +97,7 @@ class Module(core.Object):
-------
b : Bool
True if module (or one of its imports) has a definition for name.
+
"""
return _ffi_api.ModuleImplementsFunction(self, name, query_imports)
@@ -118,6 +125,7 @@ class Module(core.Object):
-------
f : tvm_ffi.Function
The result function.
+
"""
func = _ffi_api.ModuleGetFunction(self, name, query_imports)
if func is None:
@@ -131,15 +139,18 @@ class Module(core.Object):
----------
module : tvm.runtime.Module
The other module.
+
"""
_ffi_api.ModuleImportModule(self, module)
def __getitem__(self, name):
+ """Return function by name using item access (module["func"])."""
if not isinstance(name, str):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
+ """Call the module's entry function (`main`)."""
# pylint: disable=not-callable
return self.main(*args)
@@ -155,6 +166,7 @@ class Module(core.Object):
-------
source : str
The result source code.
+
"""
return _ffi_api.ModuleInspectSource(self, fmt)
@@ -169,40 +181,44 @@ class Module(core.Object):
-------
mask : int
Bitmask of runtime module property
+
"""
return _ffi_api.ModuleGetPropertyMask(self)
def is_binary_serializable(self):
- """Module 'binary serializable', save_to_bytes is supported.
+ """Return whether the module is binary serializable (supports
save_to_bytes).
Returns
-------
b : Bool
True if the module is binary serializable.
+
"""
return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0
def is_runnable(self):
- """Module 'runnable', get_function is supported.
+ """Return whether the module is runnable (supports get_function).
Returns
-------
b : Bool
True if the module is runnable.
+
"""
return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0
def is_compilation_exportable(self):
- """Module 'compilation exportable', write_to_file is supported for
object or source.
+ """Return whether the module is compilation exportable.
+
+ write_to_file is supported for object or source.
Returns
-------
b : Bool
True if the module is compilation exportable.
+
"""
- return (
- self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE
- ) != 0
+ return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0
def clear_imports(self):
"""Remove all imports of the module."""
@@ -221,6 +237,7 @@ class Module(core.Object):
See Also
--------
runtime.Module.export_library : export the module to shared library.
+
"""
_ffi_api.ModuleWriteToFile(self, file_name, fmt)
@@ -245,6 +262,7 @@ def system_lib(symbol_prefix=""):
-------
module : runtime.Module
The system-wide library module.
+
"""
return _ffi_api.SystemLib(symbol_prefix)
@@ -272,5 +290,6 @@ def load_module(path):
See Also
--------
:py:class:`tvm_ffi.Module`
+
"""
return _ffi_api.ModuleLoadFromFile(path)
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index f31dea3..60c8ded 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -25,7 +25,7 @@ _SKIP_UNKNOWN_OBJECTS = False
def register_object(type_key=None):
- """register object type.
+ """Register object type.
Parameters
----------
@@ -42,16 +42,17 @@ def register_object(type_key=None):
@tvm_ffi.register_object("test.MyObject")
class MyObject(Object):
pass
+
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
- """internal register function"""
+ """Register the object type with the FFI core."""
type_index = core._object_type_key_to_index(object_name)
if type_index is None:
if _SKIP_UNKNOWN_OBJECTS:
return cls
- raise ValueError("Cannot find object type index for %s" %
object_name)
+ raise ValueError(f"Cannot find object type index for
{object_name}")
core._add_class_attrs_by_reflection(type_index, cls)
core._register_object_by_index(type_index, cls)
return cls
@@ -63,7 +64,7 @@ def register_object(type_key=None):
def register_global_func(func_name, f=None, override=False):
- """Register global function
+ """Register global function.
Parameters
----------
@@ -104,6 +105,7 @@ def register_global_func(func_name, f=None, override=False):
--------
:py:func:`tvm_ffi.get_global_func`
:py:func:`tvm_ffi.remove_global_func`
+
"""
if callable(func_name):
f = func_name
@@ -113,7 +115,7 @@ def register_global_func(func_name, f=None, override=False):
raise ValueError("expect string function name")
def register(myf):
- """internal register function"""
+ """Register the global function with the FFI core."""
return core._register_global_func(func_name, myf, override)
if f:
@@ -122,7 +124,7 @@ def register_global_func(func_name, f=None, override=False):
def get_global_func(name, allow_missing=False):
- """Get a global function by name
+ """Get a global function by name.
Parameters
----------
@@ -140,6 +142,7 @@ def get_global_func(name, allow_missing=False):
See Also
--------
:py:func:`tvm_ffi.register_global_func`
+
"""
return core._get_global_func(name, allow_missing)
@@ -151,6 +154,7 @@ def list_global_func_names():
-------
names : list
List of global functions names.
+
"""
name_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")()
num_names = name_functor(-1)
@@ -158,18 +162,19 @@ def list_global_func_names():
def remove_global_func(name):
- """Remove a global function by name
+ """Remove a global function by name.
Parameters
----------
name : str
The name of the global function
+
"""
get_global_func("ffi.FunctionRemoveGlobal")(name)
def init_ffi_api(namespace, target_module_name=None):
- """Initialize register ffi api functions into a given module
+ """Initialize register ffi api functions into a given module.
Parameters
----------
@@ -181,7 +186,6 @@ def init_ffi_api(namespace, target_module_name=None):
Examples
--------
-
A typical usage pattern is to create a _ffi_api.py file to register
the functions under a given module. The following
code populates all registered global functions
@@ -195,6 +199,7 @@ def init_ffi_api(namespace, target_module_name=None):
import tvm_ffi
tvm_ffi.init_ffi_api("mypackage", __name__)
+
"""
target_module_name = target_module_name if target_module_name else namespace
@@ -219,10 +224,10 @@ def init_ffi_api(namespace, target_module_name=None):
__all__ = [
- "register_object",
- "register_global_func",
"get_global_func",
+ "init_ffi_api",
"list_global_func_names",
+ "register_global_func",
+ "register_object",
"remove_global_func",
- "init_ffi_api",
]
diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py
index e5367d9..803d533 100644
--- a/python/tvm_ffi/serialization.py
+++ b/python/tvm_ffi/serialization.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Serialization related utilities to enable some object can be pickled"""
+"""Serialization related utilities to enable some object can be pickled."""
from typing import Any, Optional
@@ -22,8 +22,7 @@ from . import _ffi_api
def to_json_graph_str(obj: Any, metadata: Optional[dict] = None):
- """
- Dump an object to a JSON graph string.
+ """Dump an object to a JSON graph string.
The JSON graph string is a string representation of the object
graph that includes the reference information of shared objects, which can
@@ -41,13 +40,13 @@ def to_json_graph_str(obj: Any, metadata: Optional[dict] = None):
-------
json_str : str
The JSON graph string.
+
"""
return _ffi_api.ToJSONGraphString(obj, metadata)
def from_json_graph_str(json_str: str):
- """
- Load an object from a JSON graph string.
+ """Load an object from a JSON graph string.
The JSON graph string is a string representation of the object
graph that also includes the reference information.
@@ -61,6 +60,7 @@ def from_json_graph_str(json_str: str):
-------
obj : Any
The loaded object.
+
"""
return _ffi_api.FromJSONGraphString(json_str)
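As a quick illustration of the two helpers above, a round-trip sketch; it assumes `tvm_ffi.convert` is available to build an FFI object graph, which is not shown in this diff:

```
# Round-trip sketch; tvm_ffi.convert is an assumption here, not part of this commit.
import tvm_ffi
from tvm_ffi.serialization import to_json_graph_str, from_json_graph_str

obj = tvm_ffi.convert([1, 2, 3])          # build a small FFI object graph
json_str = to_json_graph_str(obj)          # shared references are preserved
restored = from_json_graph_str(json_str)   # reconstruct the graph
```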
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 084dca8..81cbabe 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -25,7 +25,8 @@ from ._tensor import device
class StreamContext:
- """StreamContext represents a stream context in the ffi system.
+ """Represent a stream context in the FFI system.
+
StreamContext helps set up the FFI environment stream via the Python `with`
statement.
When entering the `with` scope, it caches the current environment stream and
sets up the given new stream.
@@ -42,19 +43,23 @@ class StreamContext:
See Also
--------
:py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream`
+
"""
def __init__(self, device: core.Device, stream: Union[int, c_void_p]):
+ """Initialize a stream context with a device and stream handle."""
self.device_type = device.dlpack_device_type()
self.device_id = device.index
self.stream = stream
def __enter__(self):
+ """Enter the context and set the current stream."""
self.prev_stream = core._env_set_current_stream(
self.device_type, self.device_id, self.stream
)
def __exit__(self, *args):
+ """Exit the context and restore the previous stream."""
self.prev_stream = core._env_set_current_stream(
self.device_type, self.device_id, self.prev_stream
)
@@ -64,10 +69,14 @@ try:
import torch
class TorchStreamContext:
+ """Context manager that syncs Torch and FFI stream contexts."""
+
def __init__(self, context: Optional[Any]):
+ """Initialize with an optional Torch stream/graph context
wrapper."""
self.torch_context = context
def __enter__(self):
+ """Enter both Torch and FFI stream contexts."""
if self.torch_context:
self.torch_context.__enter__()
current_stream = torch.cuda.current_stream()
@@ -77,13 +86,14 @@ try:
self.ffi_context.__enter__()
def __exit__(self, *args):
+ """Exit both Torch and FFI stream contexts."""
if self.torch_context:
self.torch_context.__exit__(*args)
self.ffi_context.__exit__(*args)
def use_torch_stream(context: Optional[Any] = None):
- """
- Create a ffi stream context with given torch stream,
+ """Create an FFI stream context with a Torch stream or graph.
+
cuda graph or current stream if `None` provided.
Parameters
@@ -111,18 +121,19 @@ try:
Note
----
When working with a raw cudaStream_t handle, use :py:func:`tvm_ffi.use_raw_stream` instead.
+
"""
return TorchStreamContext(context)
except ImportError:
def use_torch_stream(context: Optional[Any] = None):
+ """Raise an informative error when Torch is unavailable."""
raise ImportError("Cannot import torch")
def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]):
- """
- Create a ffi stream context with given device and stream handle.
+ """Create a ffi stream context with given device and stream handle.
Parameters
----------
@@ -140,6 +151,7 @@ def use_raw_stream(device: core.Device, stream: Union[int,
c_void_p]):
Note
----
When working with a Torch stream or CUDA graph, use :py:func:`tvm_ffi.use_torch_stream` instead.
+
"""
if not isinstance(stream, (int, c_void_p)):
raise ValueError(
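For context, the raw-stream path above reads like this in use; a sketch, where the integer handle `0` is a placeholder for a real `cudaStream_t` pointer:

```
# Sketch of use_raw_stream; a real caller passes an actual cudaStream_t handle.
import tvm_ffi

dev = tvm_ffi.device("cuda", 0)
with tvm_ffi.use_raw_stream(dev, 0):  # 0 is a placeholder stream handle
    ...  # FFI calls in this scope see the given environment stream
```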
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 843a10c..3c173dc 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -23,21 +23,16 @@ from .registry import register_object
@register_object("testing.TestObjectBase")
class TestObjectBase(Object):
- """
- Test object base class.
- """
+ """Test object base class."""
@register_object("testing.TestObjectDerived")
class TestObjectDerived(TestObjectBase):
- """
- Test object derived class.
- """
+ """Test object derived class."""
def create_object(type_key: str, **kwargs) -> Object:
- """
- Make an object by reflection.
+ """Make an object by reflection.
Parameters
----------
@@ -55,6 +50,7 @@ def create_object(type_key: str, **kwargs) -> Object:
----
This function is only used for testing purposes and should
not be used in other cases.
+
"""
args = [type_key]
for k, v in kwargs.items():
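The reflection helper above is exercised later in tests/python/test_object.py; in brief:

```
# Mirrors the usage in tests/python/test_object.py further down this diff.
import tvm_ffi.testing

obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello")
assert obj.v_i64 == 10
```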
diff --git a/python/tvm_ffi/utils/__init__.py b/python/tvm_ffi/utils/__init__.py
index 543bd0f..896001e 100644
--- a/python/tvm_ffi/utils/__init__.py
+++ b/python/tvm_ffi/utils/__init__.py
@@ -14,5 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Utilities used by the tvm_ffi Python package."""
from .lockfile import FileLock
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index 55ab41f..b317f04 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Simple cross-platform advisory file lock utilities."""
import os
import sys
@@ -27,34 +28,33 @@ else:
class FileLock:
- """
- A cross-platform file locking mechanism using Python's standard library.
+ """Provide a cross-platform file locking mechanism using Python's stdlib.
+
This class implements an advisory lock, which must be respected by all
cooperating processes.
"""
def __init__(self, lock_file_path):
+ """Initialize a file lock using the given lock file path."""
self.lock_file_path = lock_file_path
self._file_descriptor = None
def __enter__(self):
- """
- Context manager protocol: acquire the lock upon entering the 'with' block.
- This method will block indefinitely until the lock is acquired.
+ """Acquire the lock upon entering the context.
+
+ This method blocks until the lock is acquired.
"""
self.blocking_acquire()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- """
- Context manager protocol: release the lock upon exiting the 'with'
block.
- """
+ """Context manager protocol: release the lock upon exiting the 'with'
block."""
self.release()
return False # Propagate exceptions, if any
def acquire(self):
- """
- Acquires an exclusive, non-blocking lock on the file.
+ """Acquire an exclusive, non-blocking lock on the file.
+
Returns True if the lock was acquired, False otherwise.
"""
try:
@@ -64,12 +64,10 @@ class FileLock:
)
msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1)
else: # Unix-like systems
- self._file_descriptor = os.open(
- self.lock_file_path, os.O_WRONLY | os.O_CREAT
- )
+ self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT)
fcntl.flock(self._file_descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB)
return True
- except (IOError, BlockingIOError):
+ except (OSError, BlockingIOError):
if self._file_descriptor is not None:
os.close(self._file_descriptor)
self._file_descriptor = None
@@ -81,13 +79,13 @@ class FileLock:
raise RuntimeError(f"An unexpected error occurred: {e}")
def blocking_acquire(self, timeout=None, poll_interval=0.1):
- """
- Waits until an exclusive lock can be acquired, with an optional
timeout.
+ """Wait until an exclusive lock can be acquired, with an optional
timeout.
Args:
timeout (float): The maximum time to wait for the lock in seconds.
A value of None means wait indefinitely.
poll_interval (float): The time to wait between lock attempts in
seconds.
+
"""
start_time = time.time()
while True:
@@ -103,9 +101,7 @@ class FileLock:
time.sleep(poll_interval)
def release(self):
- """
- Releases the lock and closes the file descriptor.
- """
+ """Releases the lock and closes the file descriptor."""
if self._file_descriptor is not None:
if sys.platform == "win32":
msvcrt.locking(self._file_descriptor, msvcrt.LK_UNLCK, 1)
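Typical use of the class above, for reference; a sketch, and the lock path is arbitrary:

```
# Advisory-lock sketch; the path below is a made-up example.
from tvm_ffi.utils import FileLock

with FileLock("/tmp/example.lock"):
    ...  # critical section; the lock is released on exit, even on error
```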
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index 48df954..9fcce8b 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -18,9 +18,9 @@
import argparse
import fnmatch
-import os
import subprocess
import sys
+from pathlib import Path
header_cstyle = """
/*
@@ -172,7 +172,7 @@ SKIP_LIST = []
def should_skip_file(filepath):
- """Check if file should be skipped based on SKIP_LIST"""
+ """Check if file should be skipped based on SKIP_LIST."""
for pattern in SKIP_LIST:
if fnmatch.fnmatch(filepath, pattern):
return True
@@ -180,17 +180,15 @@ def should_skip_file(filepath):
def get_git_files():
- """Get list of files tracked by git"""
+ """Get list of files tracked by git."""
try:
result = subprocess.run(
- ["git", "ls-files"], capture_output=True, text=True,
cwd=os.getcwd()
+ ["git", "ls-files"], check=False, capture_output=True, text=True,
cwd=Path.cwd()
)
if result.returncode == 0:
return [line.strip() for line in result.stdout.split("\n") if
line.strip()]
else:
- print(
-     "Error: Could not get git files. Make sure you're in a git repository."
- )
+ print("Error: Could not get git files. Make sure you're in a git repository.")
print("Git command failed:", result.stderr.strip())
return None
except FileNotFoundError:
@@ -211,12 +209,12 @@ def copyright_line(line):
def check_header(fname, header):
- """Check header status of file without modifying it"""
- if not os.path.exists(fname):
+ """Check header status of file without modifying it."""
+ if not Path(fname).exists():
print(f"ERROR: Cannot find {fname}")
return False
- lines = open(fname).readlines()
+ lines = Path(fname).open().readlines()
has_asf_header = False
has_copyright = False
@@ -243,7 +241,7 @@ def check_header(fname, header):
def collect_files():
- """Collect all files that need header checking from git"""
+ """Collect all files that need header checking from git."""
files = []
# Get files from git (required)
@@ -261,26 +259,25 @@ def collect_files():
# Check if this file type is supported
suffix = git_file.split(".")[-1] if "." in git_file else ""
- basename = os.path.basename(git_file)
+ basename = Path(git_file).name
if (
suffix in FMT_MAP
or basename == "gradle.properties"
- or suffix == ""
- and basename in ["CMakeLists", "Makefile"]
+ or (suffix == "" and basename in ["CMakeLists", "Makefile"])
):
files.append(git_file)
return files
-def add_header(fname, header):
- """Add header to file"""
- if not os.path.exists(fname):
- print("Cannot find %s ..." % fname)
+def add_header(fname, header): # noqa: PLR0912
+ """Add header to file."""
+ if not Path(fname).exists():
+ print(f"Cannot find {fname} ...")
return
- lines = open(fname).readlines()
+ lines = Path(fname).open().readlines()
has_asf_header = False
has_copyright = False
@@ -295,7 +292,7 @@ def add_header(fname, header):
if has_asf_header and not has_copyright:
return
- with open(fname, "w") as outfile:
+ with Path(fname).open("w") as outfile:
skipline = False
if not lines:
skipline = False # File is empty
@@ -318,12 +315,12 @@ def add_header(fname, header):
outfile.write(header + "\n\n")
outfile.write("".join(lines))
if not has_asf_header:
- print("Add header to %s" % fname)
+ print(f"Add header to {fname}")
if has_copyright:
- print("Removed copyright line from %s" % fname)
+ print(f"Removed copyright line from {fname}")
-def main():
+def main(): # noqa: PLR0911, PLR0912
parser = argparse.ArgumentParser(
description="Check and fix ASF headers in source files tracked by git",
formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -400,7 +397,7 @@ Examples:
for fname in files:
processed_count += 1
suffix = fname.split(".")[-1] if "." in fname else ""
- basename = os.path.basename(fname)
+ basename = Path(fname).name
# Determine header type
if suffix in FMT_MAP:
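The PTH migration above follows the standard os.path-to-pathlib equivalences; a compact reference, where the file name is a placeholder:

```
# os.path idiom              ->  pathlib idiom used in this commit
# os.path.exists(fname)      ->  Path(fname).exists()
# os.path.basename(fname)    ->  Path(fname).name
# os.getcwd()                ->  Path.cwd()
# open(fname).readlines()    ->  Path(fname).open().readlines()
from pathlib import Path

lines = Path("example.txt").open().readlines()  # "example.txt" is a placeholder
```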
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index d666470..9517168 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -16,9 +16,9 @@
# under the License.
"""Helper tool to check file types that are allowed to checkin."""
-import os
import subprocess
import sys
+from pathlib import Path
# List of file types we allow
ALLOW_EXTENSION = {
@@ -128,12 +128,13 @@ def filename_allowed(name):
-------
allowed : bool
Whether the filename is allowed.
+
"""
arr = name.rsplit(".", 1)
if arr[-1] in ALLOW_EXTENSION:
return True
- if os.path.basename(name) in ALLOW_FILE_NAME:
+ if Path(name).name in ALLOW_FILE_NAME:
return True
if name.startswith("3rdparty"):
@@ -160,12 +161,12 @@ def copyright_line(line):
def check_asf_copyright(fname):
if fname.endswith(".png"):
return True
- if not os.path.isfile(fname):
+ if not Path(fname).is_file():
return True
has_asf_header = False
has_copyright = False
try:
- for line in open(fname):
+ for line in Path(fname).open():
if line.find("Licensed to the Apache Software Foundation") != -1:
has_asf_header = True
if copyright_line(line):
@@ -193,7 +194,7 @@ def main():
if error_list:
report = "------File type check report----\n"
report += "\n".join(error_list)
- report += "\nFound %d files that are not allowed\n" % len(error_list)
+ report += f"\nFound {len(error_list)} files that are not allowed\n"
report += (
"We do not check in binary files into the repo.\n"
"If necessary, please discuss with committers and"
@@ -212,14 +213,9 @@ def main():
if asf_copyright_list:
report = "------File type check report----\n"
report += "\n".join(asf_copyright_list) + "\n"
- report += (
-     "------Found %d files that has ASF header with copyright message----\n"
-     % len(asf_copyright_list)
- )
+ report += f"------Found {len(asf_copyright_list)} files that have ASF header with copyright message----\n"
report += "--- Files with ASF header do not need Copyright lines.\n"
- report += (
-     "--- Contributors retain copyright to their contribution by default.\n"
- )
+ report += "--- Contributors retain copyright to their contribution by default.\n"
report += "--- If a file comes with a different license, consider putting it under the 3rdparty folder instead.\n"
report += "---\n"
report += "--- You can use the following steps to remove the copyright
lines\n"
diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh
deleted file mode 100755
index fee4803..0000000
--- a/tests/lint/git-clang-format.sh
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env bash
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-set -e
-set -u
-set -o pipefail
-
-INPLACE_FORMAT=${INPLACE_FORMAT:=false}
-LINT_ALL_FILES=true
-REVISION=$(git rev-list --max-parents=0 HEAD)
-
-while (($#)); do
- case "$1" in
- -i)
- INPLACE_FORMAT=true
- shift 1
- ;;
- --rev)
- LINT_ALL_FILES=false
- REVISION=$2
- shift 2
- ;;
- *)
- echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev
<commit>]"
- echo ""
- echo "Run clang-format on files that changed since <commit> or
on all files in the repo"
- echo "Examples:"
- echo "- Compare last one commit: tests/lint/git-clang-format.sh
--rev HEAD~1"
- echo "- Compare against upstream/main:
tests/lint/git-clang-format.sh --rev upstream/main"
- echo "The -i will use black to format files in-place instead of
checking them."
- exit 1
- ;;
- esac
-done
-
-cleanup() {
- if [ -f /tmp/$$.clang-format.txt ]; then
- echo ""
- echo "---------clang-format log----------"
- cat /tmp/$$.clang-format.txt
- fi
- rm -rf /tmp/$$.clang-format.txt
-}
-trap cleanup 0
-
-CLANG_FORMAT=clang-format-15
-
-if [ -x "$(command -v clang-format-15)" ]; then
- CLANG_FORMAT=clang-format-15
-elif [ -x "$(command -v clang-format)" ]; then
- echo "clang-format might be different from clang-format-15, expect
potential difference."
- CLANG_FORMAT=clang-format
-else
- echo "Cannot find clang-format-15"
- exit 1
-fi
-
-# Print out specific version
-${CLANG_FORMAT} --version
-
-if [[ "$INPLACE_FORMAT" == "true" ]]; then
- echo "Running inplace git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION"
- exit 0
-fi
-
-if [[ "$LINT_ALL_FILES" == "true" ]]; then
- echo "Running git-clang-format against all C++ files"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
-else
- echo "Running git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
-fi
-
-if grep --quiet -E "diff" </tmp/$$.clang-format.txt; then
- echo "clang-format lint error found. Consider running clang-format-15
on these files to fix them."
- exit 1
-fi
diff --git a/tests/python/test_access_path.py b/tests/python/test_access_path.py
index f70266b..d3f59fb 100644
--- a/tests/python/test_access_path.py
+++ b/tests/python/test_access_path.py
@@ -94,25 +94,16 @@ def test_path_is_prefix_of():
assert not AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo"))
# Shorter path is prefix of longer path with same start
- assert (
- AccessPath.root()
- .attr("foo")
- .is_prefix_of(AccessPath.root().attr("foo").array_item(2))
- )
+ assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2))
# Longer path is not prefix of shorter path
assert (
- not AccessPath.root()
- .attr("foo")
- .array_item(2)
- .is_prefix_of(AccessPath.root().attr("foo"))
+ not AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo"))
)
# Different paths are not prefixes
assert (
- not AccessPath.root()
- .attr("foo")
- .is_prefix_of(AccessPath.root().attr("bar").array_item(2))
+ not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2))
)
@@ -133,16 +124,10 @@ def test_path_equal():
assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo"))
# Shorter path does not equal longer path
- assert not (
-     AccessPath.root().attr("foo") == AccessPath.root().attr("foo").array_item(2)
- )
+ assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("foo").array_item(2))
# Longer path does not equal shorter path
- assert not (
-     AccessPath.root().attr("foo").array_item(2) == AccessPath.root().attr("foo")
- )
+ assert not (AccessPath.root().attr("foo").array_item(2) == AccessPath.root().attr("foo"))
# Different paths are not equal
- assert not (
-     AccessPath.root().attr("foo") == AccessPath.root().attr("bar").array_item(2)
- )
+ assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("bar").array_item(2))
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index ae7e52c..9300bc7 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -17,7 +17,6 @@
import pickle
import pytest
-
import tvm_ffi
@@ -31,7 +30,7 @@ def test_array():
def test_bad_constructor_init_state():
- """Test when error is raised before __init_handle_by_constructor
+ """Test when error is raised before __init_handle_by_constructor.
In this case we need the FFI binding to gracefully handle both repr
and dealloc by ensuring the chandle is initialized and there is
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 41a7985..1ee735c 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -18,7 +18,6 @@
import pickle
import pytest
-
import tvm_ffi
from tvm_ffi import DLDeviceType
@@ -71,9 +70,7 @@ def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_devic
assert dev.index == expect_device_id
[email protected](
- "dev_type, dev_id", [("cpu:0:0", None), ("cpu:?", None), ("cpu:", None)]
-)
[email protected]("dev_type, dev_id", [("cpu:0:0", None), ("cpu:?",
None), ("cpu:", None)])
def test_deive_type_error(dev_type, dev_id):
with pytest.raises(ValueError):
tvm_ffi.device(dev_type, dev_id)
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 39694d1..9230ccc 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -19,7 +19,6 @@ import pickle
import numpy as np
import pytest
-
import tvm_ffi
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index 0b3b96a..7b757ad 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -17,7 +17,6 @@
import pytest
-
import tvm_ffi
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index c5f4428..43d9e1f 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -20,7 +20,6 @@ import gc
import sys
import numpy as np
-
import tvm_ffi
@@ -213,7 +212,7 @@ def test_echo_with_opaque_object():
assert sys.getrefcount(x) == 3
def py_callback(z):
- """python callback with opaque object"""
+ """Python callback with opaque object."""
assert z is x
return z
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 0d40d17..2da92e8 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -17,7 +17,6 @@
import sys
import pytest
-
import tvm_ffi
@@ -38,9 +37,7 @@ def test_method():
def test_setter():
# test setter
- obj0 = tvm_ffi.testing.create_object(
- "testing.TestObjectBase", v_i64=10, v_str="hello"
- )
+ obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10,
v_str="hello")
assert obj0.v_i64 == 10
obj0.v_i64 = 11
assert obj0.v_i64 == 11
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index c7b81a8..34e5ccc 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -16,7 +16,6 @@
# under the License.
import pytest
-
import tvm_ffi
import tvm_ffi.cpp
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 5d919a3..13cb76b 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -22,7 +22,6 @@ except ImportError:
torch = None
import numpy as np
-
import tvm_ffi
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index 96f43d3..bef5284 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -14,13 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-This script is used to benchmark the API overhead of different
-python FFI API calling overhead, through DLPack API.
+"""Benchmark API overhead of different python FFI API calling overhead through
DLPack API.
-Specifically, we would like to understand the overall overhead
-python/C++ API calls. The general goal is to understand the overall
-space and get a sense of what are the possible operations.
+Specifically, we would like to understand the overall overhead python/C++ API
calls.
+The general goal is to understand the overall space and get a sense of what
are the possible operations.
We pick function f(x, y, z) where x, y, z are length 1 tensors.
The benchmark is running in eager mode so we can see what is possible.
@@ -29,19 +26,14 @@ eliminate these overheads completely. So the goal is to get a sense
of what is possible under eager mode.
Summary of some takeaways:
-- numpy.add roughly takes 0.36 us per call, which gives roughly what can
- be done in python env.
-- torch.add on gpu takes about 3.7us per call, giving us an idea of what
- roughly we need to get to in eager mode.
--
-
+- numpy.add roughly takes 0.36 us per call, which gives roughly what can be done in python env.
+- torch.add on gpu takes about 3.7us per call, giving us an idea of what roughly we need to get to in eager mode.
"""
import time
import numpy as np
import torch
-
import tvm_ffi
@@ -54,7 +46,7 @@ def print_error(name, error):
def baseline_torch_add(repeat):
- """Run torch.add with one element"""
+ """Run torch.add with one element."""
def run_bench(device):
x = torch.arange(1, device=device)
@@ -78,7 +70,7 @@ def baseline_torch_add(repeat):
def baseline_numpy_add(repeat):
- """Run numpy.add with one element"""
+ """Run numpy.add with one element."""
x = np.arange(1)
y = np.arange(1)
z = np.arange(1)
@@ -93,9 +85,9 @@ def baseline_numpy_add(repeat):
def baseline_cupy_add(repeat):
- """Run cupy.add with one element"""
+ """Run cupy.add with one element."""
try:
- import cupy
+ import cupy # noqa: PLC0415
except ImportError:
# skip if cupy is not installed
return
@@ -130,7 +122,7 @@ def tvm_ffi_nop(repeat):
def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
- """run dlpack conversion + tvm_ffi.nop
+ """Run dlpack conversion + tvm_ffi.nop.
Measures overhead of running dlpack for each arg and then invoking
"""
@@ -151,7 +143,7 @@ def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
def tvm_ffi_nop_from_torch_dlpack(repeat):
- """run dlpack conversion + tvm_ffi.nop
+ """Run dlpack conversion + tvm_ffi.nop.
Measures overhead of running dlpack for each arg and then invoking
"""
@@ -162,7 +154,7 @@ def tvm_ffi_nop_from_torch_dlpack(repeat):
def tvm_ffi_nop_from_numpy_dlpack(repeat):
- """run dlpack conversion + tvm_ffi.nop
+ """Run dlpack conversion + tvm_ffi.nop.
Measures overhead of running dlpack for each arg and then invoking
"""
@@ -173,7 +165,7 @@ def tvm_ffi_nop_from_numpy_dlpack(repeat):
def tvm_ffi_self_dlpack_nop(repeat):
- """run dlpack conversion + tvm_ffi.nop
+ """Run dlpack conversion + tvm_ffi.nop.
Measures overhead of running dlpack for each arg and then invoking
"""
@@ -184,9 +176,8 @@ def tvm_ffi_self_dlpack_nop(repeat):
def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
- """
- Measures overhead of running dlpack for each args then invoke
- but uses the legacy torch.utils.dlpack.to_dlpack API
+ """Measures overhead of running dlpack for each args then invoke
+ but uses the legacy torch.utils.dlpack.to_dlpack API.
This helps to measure possible implementation overhead of torch.
"""
@@ -212,8 +203,7 @@ def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
- """
- Measures overhead of running dlpack via auto convert by directly
+ """Measures overhead of running dlpack via auto convert by directly
take torch.Tensor as inputs.
"""
nop = tvm_ffi.get_global_func("testing.nop")
@@ -227,8 +217,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
- """
- Measures overhead of running dlpack via auto convert by directly
+ """Measures overhead of running dlpack via auto convert by directly
take torch.Tensor as inputs.
"""
# use larger to ensure alignment req is met
@@ -241,14 +230,11 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
    f"tvm_ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
)
else:
- bench_tvm_ffi_nop_autodlpack(
-     f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat
- )
+ bench_tvm_ffi_nop_autodlpack(f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
def tvm_ffi_nop_autodlpack_from_numpy(repeat):
- """
- Measures overhead of running dlpack via auto convert by directly
+ """Measures overhead of running dlpack via auto convert by directly
take numpy.ndarray as inputs.
"""
# use larger to ensure alignment req is met
@@ -259,8 +245,7 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
- """
- Measures overhead of running dlpack via auto convert by directly
+ """Measures overhead of running dlpack via auto convert by directly
take test wrapper as inputs. This effectively measure DLPack exchange in
tvm ffi.
"""
x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
@@ -285,9 +270,7 @@ def bench_to_dlpack(x, name, repeat):
def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)):
- """
- Measures overhead of running dlpack with latest 1.1.
- """
+ """Measures overhead of running dlpack with latest 1.1."""
try:
x.__dlpack__(max_version=max_version)
start = time.time()
@@ -301,9 +284,7 @@ def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)):
def bench_torch_utils_to_dlpack(repeat):
- """
- Measures overhead of running torch.utils.dlpack.to_dlpack
- """
+ """Measures overhead of running torch.utils.dlpack.to_dlpack."""
x = torch.arange(1)
torch.utils.dlpack.to_dlpack(x)
start = time.time()
@@ -320,7 +301,7 @@ def torch_get_cuda_stream_native(device_id):
def load_torch_get_current_cuda_stream():
"""Create a faster get_current_cuda_stream for torch through cpp
extension."""
- from torch.utils import cpp_extension
+ from torch.utils import cpp_extension # noqa: PLC0415
source = """
#include <c10/cuda/CUDAStream.h>
@@ -345,9 +326,7 @@ def load_torch_get_current_cuda_stream():
def bench_torch_get_current_stream(repeat, name, func):
- """
- Measures overhead of running torch.cuda.current_stream
- """
+ """Measures overhead of running torch.cuda.current_stream."""
x = torch.arange(1, device="cuda") # noqa: F841
func(0)
start = time.time()
@@ -360,14 +339,12 @@ def bench_torch_get_current_stream(repeat, name, func):
def populate_object_table(num_classes):
nop = tvm_ffi.get_global_func("testing.nop")
- dummy_instances = [
- type(f"DummyClass{i}", (object,), {})() for i in range(num_classes)
- ]
+ dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in
range(num_classes)]
for instance in dummy_instances:
nop(instance)
-def main():
+def main(): # noqa: PLR0915
repeat = 10000
# measures impact of object dispatch table size
# takeaway so far is that there is no impact on the performance
@@ -401,12 +378,8 @@ def main():
print("---------------------------------------------------")
print("Benchmark x.__dlpack__(max_version=(1,1)) overhead")
print("---------------------------------------------------")
- bench_to_dlpack_versioned(
- torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat
- )
- bench_to_dlpack_versioned(
- np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat
- )
+ bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat)
+ bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat)
bench_to_dlpack_versioned(
tvm_ffi.from_dlpack(torch.arange(1)),
"tvm.__dlpack__(max_version=(1,1))",
@@ -415,9 +388,7 @@ def main():
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[default stream]")
print("---------------------------------------------------")
- bench_torch_get_current_stream(
- repeat, "cpp-extension", load_torch_get_current_cuda_stream()
- )
+ bench_torch_get_current_stream(repeat, "cpp-extension",
load_torch_get_current_cuda_stream())
bench_torch_get_current_stream(repeat, "python",
torch_get_cuda_stream_native)
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[non-default stream]")
diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh
deleted file mode 100755
index 5b17cf8..0000000
--- a/tests/scripts/task_lint.sh
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env bash
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-set -euxo pipefail
-
-cleanup() {
- rm -rf /tmp/$$.*
-}
-trap cleanup 0
-
-function run_lint {
- echo "Checking file types..."
- python tests/lint/check_file_type.py
-
- echo "Checking ASF headers..."
- python tests/lint/check_asf_header.py --check
-
- echo "isort check..."
- isort --check --diff .
-
- echo "black check..."
- black --check --diff .
-
- echo "ruff check..."
- ruff check --diff .
-
- echo "clang-format check..."
- tests/lint/git-clang-format.sh
-}
-
-run_lint