This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 64e4b7f chore: Modernize pre-commit hooks (#19)
64e4b7f is described below
commit 64e4b7f01896e3b755a49f5f4f7c25329aacc0b7
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 17 14:11:17 2025 -0700
chore: Modernize pre-commit hooks (#19)
This PR makes the following changes:
- Introduces pre-commit hooks, including:
- check-yaml
- check-toml
- ruff-check and ruff-format
- clang-format
- cython-lint
- shell-format
- shell-check
- ASF header checks (migrated from `task_lint.sh`)
- Filetype checks (migrated from `task_lint.sh`)
- Removes `tests/scripts/task_lint.sh` and unnecessary/conflicting
checks, including:
- isort (superseded by ruff's isort-compatible import sorting, which runs in `ruff-check`, not `ruff-format`)
- black (part of `ruff-format`)
- Updates the `lint` stage in `ci_test.yml`, replacing `task_lint.sh` with
native pre-commit hooks
- Makes the `test` stage additionally depend on `lint`
---
.github/workflows/ci_test.yml | 21 +----
.pre-commit-config.yaml | 51 +++++++++++-
README.md | 2 -
docs/_static/custom.css | 2 +-
docs/conf.py | 7 +-
docs/get_started/quick_start.md | 1 -
docs/guides/compiler_integration.md | 23 +++---
docs/guides/python_guide.md | 2 +-
docs/requirements.txt | 2 +-
.../packaging/python/my_ffi_extension/_ffi_api.py | 1 -
examples/quick_start/run_example.py | 1 -
examples/quick_start/run_example.sh | 2 +-
include/tvm/ffi/c_api.h | 20 ++---
include/tvm/ffi/reflection/registry.h | 8 +-
python/tvm_ffi/__init__.py | 1 +
python/tvm_ffi/_convert.py | 1 +
python/tvm_ffi/_dtype.py | 6 +-
python/tvm_ffi/_ffi_api.py | 1 +
python/tvm_ffi/_optional_torch_c_dlpack.py | 1 +
python/tvm_ffi/base.py | 1 +
python/tvm_ffi/config.py | 24 ++++--
python/tvm_ffi/container.py | 7 +-
python/tvm_ffi/cpp/load_inline.py | 47 ++++++++---
python/tvm_ffi/cython/base.pxi | 26 +++----
python/tvm_ffi/cython/device.pxi | 2 +-
python/tvm_ffi/cython/dtype.pxi | 1 +
python/tvm_ffi/cython/function.pxi | 26 +++----
python/tvm_ffi/cython/object.pxi | 1 +
python/tvm_ffi/cython/string.pxi | 1 -
python/tvm_ffi/cython/tensor.pxi | 3 +-
python/tvm_ffi/error.py | 1 +
python/tvm_ffi/libinfo.py | 19 ++++-
python/tvm_ffi/module.py | 4 +-
python/tvm_ffi/registry.py | 1 +
python/tvm_ffi/stream.py | 1 +
python/tvm_ffi/utils/lockfile.py | 4 +-
tests/lint/check_asf_header.py | 9 ++-
tests/lint/check_file_type.py | 12 ++-
tests/lint/git-clang-format.sh | 91 +++++++++++-----------
tests/python/test_access_path.py | 27 +++++--
tests/python/test_device.py | 4 +-
tests/python/test_object.py | 4 +-
tests/scripts/benchmark_dlpack.py | 49 +++++-------
tests/scripts/task_cpp_tests.sh | 2 +-
tests/scripts/task_lint.sh | 29 ++++---
45 files changed, 329 insertions(+), 220 deletions(-)
diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml
index 1bed651..641d3a8 100644
--- a/.github/workflows/ci_test.yml
+++ b/.github/workflows/ci_test.yml
@@ -47,30 +47,13 @@ jobs:
lint:
needs: [prepare]
- if: >
- needs.prepare.outputs.should_skip_ci_commit != 'true' &&
- needs.prepare.outputs.should_skip_ci_docs_only != 'true'
- name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- with:
- submodules: recursive
- - uses: actions/setup-python@v6
- with:
- python-version: '3.13'
- - name: Install dependencies
- run: |
- pip install black pylint ruff
- sudo apt-get install -y clang-format-15
-
- - name: Lint
- run: |
- tests/scripts/task_lint.sh
-
+ - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd #
v3.0.1
test:
- needs: [prepare]
+ needs: [lint, prepare]
if: >
needs.prepare.outputs.should_skip_ci_commit != 'true' &&
needs.prepare.outputs.should_skip_ci_docs_only != 'true'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dc24845..f31f7f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,8 +15,29 @@
# specific language governing permissions and limitations
# under the License.
+# TODO(@junrushao): adding a few extra hooks:
+# - Python type checking via mypy or ty
+# - CMake linters
+# - Conventional commits
+default_install_hook_types:
+ - pre-commit
repos:
- # Standard hooks
+ - repo: local
+ hooks:
+ - id: check-asf-header
+ name: check ASF Header
+ entry: python tests/lint/check_asf_header.py --check
+ language: system
+ pass_filenames: false
+ verbose: false
+ - repo: local
+ hooks:
+ - id: check-file-type
+ name: check file types
+ entry: python tests/lint/check_file_type.py
+ language: system
+ pass_filenames: false
+ verbose: false
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
@@ -28,3 +49,31 @@ repos:
- id: mixed-line-ending
- id: requirements-txt-fixer
- id: trailing-whitespace
+ - id: check-yaml
+ - id: check-toml
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.12.3
+ hooks:
+ - id: ruff-check
+ types_or: [python, pyi, jupyter]
+ args: [--fix]
+ - id: ruff-format
+ types_or: [python, pyi, jupyter]
+ - repo: https://github.com/pre-commit/mirrors-clang-format
+ rev: "v20.1.8"
+ hooks:
+ - id: clang-format
+ - repo: https://github.com/MarcoGorelli/cython-lint
+ rev: v0.16.7
+ hooks:
+ - id: cython-lint
+ args: [--max-line-length=120]
+ - id: double-quote-cython-strings
+ - repo: https://github.com/scop/pre-commit-shfmt
+ rev: v3.12.0-2
+ hooks:
+ - id: shfmt
+ - repo: https://github.com/shellcheck-py/shellcheck-py
+ rev: v0.10.0.1
+ hooks:
+ - id: shellcheck
diff --git a/README.md b/README.md
index 525518e..88083e9 100644
--- a/README.md
+++ b/README.md
@@ -18,5 +18,3 @@
# tvm ffi
[](https://github.com/apache/tvm-ffi/actions/workflows/ci_test.yml)
-
-
diff --git a/docs/_static/custom.css b/docs/_static/custom.css
index 6277c6d..088d499 100644
--- a/docs/_static/custom.css
+++ b/docs/_static/custom.css
@@ -2,4 +2,4 @@
See: https://github.com/executablebooks/sphinx-book-theme/issues/732 */
#rtd-footer-container {
margin: 0px !important;
-}
\ No newline at end of file
+}
diff --git a/docs/conf.py b/docs/conf.py
index 8a1c5b3..e621878 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -16,7 +16,6 @@
# under the License.
# -*- coding: utf-8 -*-
import os
-import sys
import tomli
@@ -198,7 +197,7 @@ def footer_html():
<div class="dropdown">
<button class="btn btn-link dropdown-toggle" type="button"
id="footerDropdown" data-bs-toggle="dropdown"
aria-expanded="false" style="font-size: 0.9em; color:
#6c757d; text-decoration: none; padding: 0; border: none; background: none;">
- {footer_dropdown['name']}
+ {footer_dropdown["name"]}
</button>
<ul class="dropdown-menu" aria-labelledby="footerDropdown"
style="font-size: 0.9em;">
{dropdown_items} </ul>
@@ -226,6 +225,6 @@ html_context = {
"conf_py_path": "/docs/",
}
-html_static_path = ['_static']
+html_static_path = ["_static"]
-html_css_files = ['custom.css']
+html_css_files = ["custom.css"]
diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index 449c1db..043ea28 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -279,4 +279,3 @@ The main takeaway points are:
- **ffi::Tensor** is a universal tensor structure that enables zero-copy
exchange of array data
- **Module loading** is provided by tvm ffi APIs in multiple languages.
- **C ABI** is provided for easy low-level integration
-
diff --git a/docs/guides/compiler_integration.md
b/docs/guides/compiler_integration.md
index a1355af..7338dbf 100644
--- a/docs/guides/compiler_integration.md
+++ b/docs/guides/compiler_integration.md
@@ -112,14 +112,14 @@ with various kernel DSLs and libraries.
## Runtime and State Management for Compilers
-While TVM FFI provides a standard ABI for compiler-generated kernels, many
compilers and domain-specific languages
-(DSLs) require their own **runtime** to manage states like dynamic shapes,
workspace memory, or other
-application-specific data. This runtime can be a separate shared library
accessible to all kernels from a specific
+While TVM FFI provides a standard ABI for compiler-generated kernels, many
compilers and domain-specific languages
+(DSLs) require their own **runtime** to manage states like dynamic shapes,
workspace memory, or other
+application-specific data. This runtime can be a separate shared library
accessible to all kernels from a specific
compiler.
### Recommended Approach for State Management
-The recommended approach for managing compiler-specific state is to define the
state within a **separate shared library**.
+The recommended approach for managing compiler-specific state is to define the
state within a **separate shared library**.
This library exposes its functionality by registering functions as global
`tvm::ffi::Function`s.
Here's a breakdown of the process:
@@ -144,21 +144,21 @@ Here's a breakdown of the process:
This method allows both C++ and Python to access the runtime state through
a consistent API.
3. **Access State from Kernels**: Within your compiler-generated kernels, you
can use
`GetGlobalRequired("mylang.get_global_state")` in C++ or the C equivalent
- `TVMFFIGetGlobalFunction("mylang.get_global_state", ...)` to get the
function and then call it to retrieve the state
+ `TVMFFIGetGlobalFunction("mylang.get_global_state", ...)` to get the
function and then call it to retrieve the state
pointer.
### Distributing the Runtime
-For a user to use a kernel from your compiler, they must have access to your
runtime library. The preferred method is to
-package the runtime shared library (e.g., `libmylang_runtime.so`) as part of a
Python or C++ package. Users must install
-and import this package before loading any kernels compiled by your system.
+For a user to use a kernel from your compiler, they must have access to your
runtime library. The preferred method is to
+package the runtime shared library (e.g., `libmylang_runtime.so`) as part of a
Python or C++ package. Users must install
+and import this package before loading any kernels compiled by your system.
This approach ensures the state is shared among different kernels.
### Common vs. Custom State
-It's important to distinguish between compiler-specific state and **common
state** managed by TVM FFI. TVM FFI handles
-common states like **streams** and **memory allocators** through environment
functions (e.g., `TVMFFIEnvGetStream`),
-allowing kernels to access these without managing their own. However, for any
unique state required by your compiler,
+It's important to distinguish between compiler-specific state and **common
state** managed by TVM FFI. TVM FFI handles
+common states like **streams** and **memory allocators** through environment
functions (e.g., `TVMFFIEnvGetStream`),
+allowing kernels to access these without managing their own. However, for any
unique state required by your compiler,
the global function registration approach is the most suitable method.
## Advanced: Custom Modules
@@ -196,4 +196,3 @@ the overall import relations from `<import_tree>` and
return the final composed
As long as the compiler generates the `__tvm_ffi__library_bin` in the above
format, {py:func}`tvm_ffi.load_module` will correctly
handle the loading and recover the original module. Note that we will need the
custom module class definition to be available
during loading, either by importing another runtime DLL, or embedding it in
the generated library.
-
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index fdf03a5..cd997af 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -178,7 +178,7 @@ torch.testing.assert_close(x + 1, y)
```
The above code defines a C++ function `add_one_cpu` in Python script, compiles
it on the fly and then loads the compiled
-{py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You
can then call the function `add_one_cpu`
+{py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You
can then call the function `add_one_cpu`
from the module as usual.
## Error Handling
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 74784b5..55a8565 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
autodocsumm
-exhale
breathe
+exhale
linkify-it-py
matplotlib
myst-parser
diff --git a/examples/packaging/python/my_ffi_extension/_ffi_api.py
b/examples/packaging/python/my_ffi_extension/_ffi_api.py
index 616b1ee..5e03489 100644
--- a/examples/packaging/python/my_ffi_extension/_ffi_api.py
+++ b/examples/packaging/python/my_ffi_extension/_ffi_api.py
@@ -17,7 +17,6 @@
import tvm_ffi
# make sure lib is loaded first
-from .base import _LIB
# this is a short cut to register all the global functions
# prefixed by `my_ffi_extension.` to this module
diff --git a/examples/quick_start/run_example.py
b/examples/quick_start/run_example.py
index e126af1..698bc2a 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -21,7 +21,6 @@ try:
except ImportError:
torch = None
-import ctypes
import numpy
diff --git a/examples/quick_start/run_example.sh
b/examples/quick_start/run_example.sh
index 0602b85..09d8daa 100755
--- a/examples/quick_start/run_example.sh
+++ b/examples/quick_start/run_example.sh
@@ -1,3 +1,4 @@
+#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#!/bin/bash
set -ex
cmake -B build -S .
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 3dcdf4f..0ab4d08 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -272,16 +272,16 @@ typedef struct {
*/
uint32_t small_str_len;
};
- union { // 8 bytes
- int64_t v_int64; // integers
- double v_float64; // floating-point numbers
- void* v_ptr; // typeless pointers
- const char* v_c_str; // raw C-string
- TVMFFIObject* v_obj; // ref counted objects
- DLDataType v_dtype; // data type
- DLDevice v_device; // device
- char v_bytes[8]; // small string
- uint64_t v_uint64; // uint64 repr mainly used for hashing
+ union { // 8 bytes
+ int64_t v_int64; // integers
+ double v_float64; // floating-point numbers
+ void* v_ptr; // typeless pointers
+ const char* v_c_str; // raw C-string
+ TVMFFIObject* v_obj; // ref counted objects
+ DLDataType v_dtype; // data type
+ DLDevice v_device; // device
+ char v_bytes[8]; // small string
+ uint64_t v_uint64; // uint64 repr mainly used for hashing
};
} TVMFFIAny;
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 6a1a9b5..c0d984f 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -113,7 +113,7 @@ class AttachFieldFlag : public FieldInfoTrait {
* \returns The byteoffset
*/
template <typename Class, typename T>
-TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
+TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
int64_t field_offset_to_class =
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
return field_offset_to_class -
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
@@ -350,7 +350,7 @@ class ObjectDef : public ReflectionDefBase {
* \return The reflection definition.
*/
template <typename T, typename BaseClass, typename... Extra>
- TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr,
Extra&&... extra) {
+ TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr,
Extra&&... extra) {
RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
return *this;
}
@@ -369,7 +369,7 @@ class ObjectDef : public ReflectionDefBase {
* \return The reflection definition.
*/
template <typename T, typename BaseClass, typename... Extra>
- TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr,
Extra&&... extra) {
+ TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr,
Extra&&... extra) {
static_assert(Class::_type_mutable, "Only mutable classes are supported
for writable fields");
RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
return *this;
@@ -430,7 +430,7 @@ class ObjectDef : public ReflectionDefBase {
}
template <typename T, typename BaseClass, typename... ExtraArgs>
- void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable,
+ void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
ExtraArgs&&... extra_args) {
static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a
base class of Class");
TVMFFIFieldInfo info;
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 9bafe2b..b3b070f 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TVM FFI Python package."""
+
# order matters here so we need to skip isort here
# isort: skip_file
# base always go first to load the libtvm_ffi
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 7c7b515..cf311b2 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Conversion utilities to bring python objects into ffi values."""
+
from numbers import Number
from typing import Any
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 30409e4..1664d98 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""dtype class."""
+
# pylint: disable=invalid-name
from enum import IntEnum
@@ -83,7 +84,10 @@ class dtype(str):
The new dtype with the given number of lanes.
"""
cdtype = core._create_dtype_from_tuple(
- core.DataType, self.__tvm_ffi_dtype__.type_code,
self.__tvm_ffi_dtype__.bits, lanes
+ core.DataType,
+ self.__tvm_ffi_dtype__.type_code,
+ self.__tvm_ffi_dtype__.bits,
+ lanes,
)
val = str.__new__(dtype, str(cdtype))
val.__tvm_ffi_dtype__ = cdtype
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 1c2326c..f9314dd 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""FFI API."""
+
from . import registry
registry.init_ffi_api("ffi", __name__)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 6ff77cc..b96e9d0 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -29,6 +29,7 @@ likely be phased away and deleted after changes landed and
released in pytorch.
This module will load slowly at first time due to JITing,
subsequent calls will be much faster.
"""
+
import warnings
from . import libinfo
diff --git a/python/tvm_ffi/base.py b/python/tvm_ffi/base.py
index fbcc01c..8099955 100644
--- a/python/tvm_ffi/base.py
+++ b/python/tvm_ffi/base.py
@@ -16,6 +16,7 @@
# under the License.
# coding: utf-8
"""Base library for TVM FFI."""
+
import ctypes
import logging
import os
diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py
index 4e87caa..7e03680 100644
--- a/python/tvm_ffi/config.py
+++ b/python/tvm_ffi/config.py
@@ -37,16 +37,28 @@ def __main__():
description="Get various configuration information needed to compile
with tvm-ffi"
)
- parser.add_argument("--includedir", action="store_true", help="Print
include directory")
parser.add_argument(
- "--dlpack-includedir", action="store_true", help="Print dlpack include
directory"
+ "--includedir", action="store_true", help="Print include directory"
+ )
+ parser.add_argument(
+ "--dlpack-includedir",
+ action="store_true",
+ help="Print dlpack include directory",
+ )
+ parser.add_argument(
+ "--cmakedir", action="store_true", help="Print library directory"
+ )
+ parser.add_argument(
+ "--sourcedir", action="store_true", help="Print source directory"
+ )
+ parser.add_argument(
+ "--libfiles", action="store_true", help="Fully qualified library
filenames"
)
- parser.add_argument("--cmakedir", action="store_true", help="Print library
directory")
- parser.add_argument("--sourcedir", action="store_true", help="Print source
directory")
- parser.add_argument("--libfiles", action="store_true", help="Fully
qualified library filenames")
parser.add_argument("--libdir", action="store_true", help="Print library
directory")
parser.add_argument("--libs", action="store_true", help="Libraries to be
linked")
- parser.add_argument("--cython-lib-path", action="store_true", help="Print
cython path")
+ parser.add_argument(
+ "--cython-lib-path", action="store_true", help="Print cython path"
+ )
parser.add_argument("--cxxflags", action="store_true", help="Print cxx
flags")
parser.add_argument("--cflags", action="store_true", help="Print c flags")
parser.add_argument("--ldflags", action="store_true", help="Print ld
flags")
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index f64028f..8368cd4 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Container classes."""
+
import collections.abc
from typing import Any, Mapping, Sequence
@@ -248,4 +249,8 @@ class Map(core.Object, collections.abc.Mapping):
# exception safety handling for chandle=None
if self.__chandle__() == 0:
return type(self).__name__ + "(chandle=None)"
- return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in
self.items()]) + "}"
+ return (
+ "{"
+ + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in
self.items()])
+ + "}"
+ )
diff --git a/python/tvm_ffi/cpp/load_inline.py
b/python/tvm_ffi/cpp/load_inline.py
index ced9705..6ce3d11 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -86,7 +86,9 @@ def _find_cuda_home() -> Optional[str]:
else:
# Guess #3
if IS_WINDOWS:
- cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing
Toolkit/CUDA/v*.*")
+ cuda_homes = glob.glob(
+ "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
+ )
if len(cuda_homes) == 0:
cuda_home = ""
else:
@@ -162,7 +164,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
raise FileNotFoundError("No Visual Studio installation found.")
# Construct the path to the VsDevCmd.bat file
- vsdevcmd_path = os.path.join(vs_install_path, "Common7", "Tools",
"VsDevCmd.bat")
+ vsdevcmd_path = os.path.join(
+ vs_install_path, "Common7", "Tools", "VsDevCmd.bat"
+ )
if not os.path.exists(vsdevcmd_path):
raise FileNotFoundError(f"VsDevCmd.bat not found at:
{vsdevcmd_path}")
@@ -175,7 +179,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
)
# Execute the command in a new shell
- return subprocess.run(cmd_command, cwd=cwd,
capture_output=capture_output, shell=True)
+ return subprocess.run(
+ cmd_command, cwd=cwd, capture_output=capture_output, shell=True
+ )
except (FileNotFoundError, subprocess.CalledProcessError) as e:
raise RuntimeError(
@@ -217,7 +223,11 @@ def _generate_ninja_build(
"/EHsc",
]
default_cuda_cflags = ["-Xcompiler", "/std:c++17", "/O2"]
- default_ldflags = ["/DLL", f"/LIBPATH:{tvm_ffi_lib_path}",
f"{tvm_ffi_lib_name}.lib"]
+ default_ldflags = [
+ "/DLL",
+ f"/LIBPATH:{tvm_ffi_lib_path}",
+ f"{tvm_ffi_lib_name}.lib",
+ ]
else:
default_cflags = ["-std=c++17", "-fPIC", "-O2"]
default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"]
@@ -226,12 +236,17 @@ def _generate_ninja_build(
if with_cuda:
# determine the compute capability of the current GPU
default_cuda_cflags += [_get_cuda_target()]
- default_ldflags += ["-L{}".format(os.path.join(_find_cuda_home(),
"lib64")), "-lcudart"]
+ default_ldflags += [
+ "-L{}".format(os.path.join(_find_cuda_home(), "lib64")),
+ "-lcudart",
+ ]
cflags = default_cflags + [flag.strip() for flag in extra_cflags]
cuda_cflags = default_cuda_cflags + [flag.strip() for flag in
extra_cuda_cflags]
ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
- include_paths = default_include_paths + [os.path.abspath(path) for path in
extra_include_paths]
+ include_paths = default_include_paths + [
+ os.path.abspath(path) for path in extra_include_paths
+ ]
# append include paths
for path in include_paths:
@@ -241,7 +256,9 @@ def _generate_ninja_build(
# flags
ninja = []
ninja.append("ninja_required_version = 1.3")
- ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS
else "c++")))
+ ninja.append(
+ "cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))
+ )
ninja.append("cflags = {}".format(" ".join(cflags)))
if with_cuda:
ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin",
"nvcc")))
@@ -290,7 +307,9 @@ def _generate_ninja_build(
)
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if
with_cuda else ""))
+ ninja.append(
+ "build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda
else "")
+ )
ninja.append("")
# default target
@@ -306,7 +325,9 @@ def _build_ninja(build_dir: str) -> None:
if num_workers is not None:
command += ["-j", num_workers]
if IS_WINDOWS:
- status = _run_command_in_dev_prompt(args=command, cwd=build_dir,
capture_output=True)
+ status = _run_command_in_dev_prompt(
+ args=command, cwd=build_dir, capture_output=True
+ )
else:
status = subprocess.run(args=command, cwd=build_dir,
capture_output=True)
if status.returncode != 0:
@@ -508,7 +529,9 @@ def load_inline(
extra_ldflags,
extra_include_paths,
)
- build_dir: str = os.path.join(build_directory, "{}_{}".format(name,
source_hash))
+ build_dir: str = os.path.join(
+ build_directory, "{}_{}".format(name, source_hash)
+ )
else:
build_dir = os.path.abspath(build_directory)
os.makedirs(build_dir, exist_ok=True)
@@ -536,4 +559,6 @@ def load_inline(
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- return load_module(os.path.abspath(os.path.join(build_dir,
"{}{}".format(name, ext))))
+ return load_module(
+ os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext)))
+ )
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 77c9c7e..6fe10fd 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -108,7 +108,6 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74
-
ctypedef void* TVMFFIObjectHandle
ctypedef struct TVMFFIObject:
@@ -153,9 +152,9 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
- ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept;
- ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value)
noexcept;
- ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept;
+ ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
+ ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value)
noexcept
+ ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept
ctypedef struct TVMFFIFieldInfo:
TVMFFIByteArray name
@@ -202,7 +201,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t
num_args,
TVMFFIAny* result) nogil
int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call,
- void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
+ void (*deleter)(void*), TVMFFIObjectHandle* out)
nogil
int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out)
nogil
int TVMFFIFunctionSetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle f,
int override) nogil
int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle*
out) nogil
@@ -216,17 +215,17 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil
- const TVMFFIByteArray* TVMFFITraceback(
- const char* filename, int lineno, const char* func, int
cross_ffi_boundary) nogil;
+ const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno,
+ const char* func, int
cross_ffi_boundary) nogil
int TVMFFITensorFromDLPack(DLManagedTensor* src, int32_t require_alignment,
- int32_t require_contiguous,
TVMFFIObjectHandle* out) nogil
+ int32_t require_contiguous, TVMFFIObjectHandle*
out) nogil
int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* src,
int32_t require_alignment,
int32_t require_contiguous,
TVMFFIObjectHandle* out) nogil
int TVMFFITensorToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out)
nogil
int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle src,
- DLManagedTensorVersioned** out) nogil
+ DLManagedTensorVersioned** out) nogil
const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny*
value) nogil
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
@@ -241,9 +240,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil
- int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
- TVMFFIStreamHandle stream,
- TVMFFIStreamHandle* opt_out_original_stream)
nogil
+ int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
+ TVMFFIStreamHandle* opt_out_original_stream) nogil
+
def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
cdef TVMFFIStreamHandle prev_stream = NULL
@@ -256,8 +255,7 @@ def _env_set_current_stream(int device_type, int device_id,
uint64_t stream):
cdef extern from "tvm_ffi_python_helpers.h":
- # no need to expose fields of the call context
- # setter data structure
+ # no need to expose fields of the call context setter data structure
ctypedef int (*DLPackFromPyObject)(
void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle*
env_stream
) except -1
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 85740a0..84c047f 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -20,6 +20,7 @@ from enum import IntEnum
_CLASS_DEVICE = None
+
def _set_class_device(cls):
global _CLASS_DEVICE
_CLASS_DEVICE = cls
@@ -162,7 +163,6 @@ cdef class Device:
def __hash__(self):
return hash((self.cdevice.device_type, self.cdevice.device_id))
-
def __device_type_name__(self):
return self._DEVICE_TYPE_TO_NAME[self.cdevice.device_type]
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 4656b04..0036c80 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -18,6 +18,7 @@
_CLASS_DTYPE = None
+
def _set_class_dtype(cls):
global _CLASS_DTYPE
_CLASS_DTYPE = cls
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index c80e238..c4662ca 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -88,9 +88,9 @@ cdef inline object make_ret(TVMFFIAny result,
DLPackToPyObject c_dlpack_to_pyobj
raise ValueError("Unhandled type index %d" % type_index)
-##----------------------------------------------------------------------------
-## Helper to simplify calling constructor
-##----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+# Helper to simplify calling constructor
+# ----------------------------------------------------------------------------
cdef inline int ConstructorCall(void* constructor_handle,
PyObject* py_arg_tuple,
void** handle,
@@ -109,9 +109,9 @@ cdef inline int ConstructorCall(void* constructor_handle,
handle[0] = result.v_ptr
return 0
-##----------------------------------------------------------------------------
-## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_
-##----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+# Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_
+# ----------------------------------------------------------------------------
cdef int TVMFFIPyArgSetterTensor_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
@@ -219,8 +219,7 @@ cdef int TVMFFIPyArgSetterDLPack_(
out.v_ptr = temp_chandle
# record the stream from the source framework context when possible
temp_dltensor = TVMFFITensorGetDLTensorPtr(temp_chandle)
- if (temp_dltensor.device.device_type != kDLCPU and
- ctx.device_type != -1):
+ if (temp_dltensor.device.device_type != kDLCPU and ctx.device_type != -1):
# __tvm_ffi_env_stream__ returns the expected stream that should be set
# through TVMFFIEnvSetStream when calling a TVM FFI function
if hasattr(arg, "__tvm_ffi_env_stream__"):
@@ -571,9 +570,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
out.func = TVMFFIPyArgSetterFallback_
return 0
-#---------------------------------------------------------------------------------------------
-## Implementation of function calling
-#---------------------------------------------------------------------------------------------
+#
---------------------------------------------------------------------------------------------
+# Implementation of function calling
+#
---------------------------------------------------------------------------------------------
cdef class Function(Object):
"""Python class that wraps a function with tvm-ffi ABI.
@@ -591,6 +590,7 @@ cdef class Function(Object):
property release_gil:
def __get__(self):
return self.c_release_gil != 0
+
def __set__(self, value):
self.c_release_gil = value
@@ -747,7 +747,7 @@ def _get_global_func(name, allow_missing):
return ret
if allow_missing:
- return None
+ return None
raise ValueError("Cannot find global function %s" % name)
@@ -835,7 +835,7 @@ def _convert_to_opaque_object(object pyobject):
def _print_debug_info():
"""Get the size of the dispatch map"""
- cdef size_t size = TVMFFIPyGetDispatchMapSize()
+ cdef size_t size = TVMFFIPyGetDispatchMapSize()
print(f"TVMFFIPyGetDispatchMapSize: {size}")
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 90821e3..f47018c 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -250,6 +250,7 @@ def _object_type_key_to_index(str type_key):
return tidx
return None
+
cdef inline str _type_index_to_key(int32_t tindex):
"""get the type key of object class"""
cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex)
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index 4119e7b..0f9d11b 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -28,7 +28,6 @@ cdef inline bytes _bytes_obj_get_py_bytes(obj):
return bytearray_to_bytes(bytes)
-
class String(str, PyNativeObject):
__slots__ = ["__tvm_ffi_object__"]
"""String object that is possibly returned by FFI call.
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 1255f0b..dc8b75e 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -106,7 +106,7 @@ cdef inline int _from_dlpack_universal(
# move to false as most frameworks get upgraded.
cdef int favor_legacy_dlpack = True
- if hasattr(ext_tensor, '__dlpack__'):
+ if hasattr(ext_tensor, "__dlpack__"):
if favor_legacy_dlpack:
_from_dlpack(
ext_tensor.__dlpack__(),
@@ -305,6 +305,7 @@ cdef class DLTensorTestWrapper:
cdef Tensor tensor
cdef dict __dict__
+
def __init__(self, tensor):
self.tensor = tensor
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index b09200d..cec6956 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Error handling."""
+
import ast
import re
import sys
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index 55c5140..8325c35 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -75,7 +75,9 @@ def find_libtvm_ffi():
lib_found = [p for p in lib_dll_path if os.path.exists(p) and
os.path.isfile(p)]
if not lib_found:
- raise RuntimeError(f"Cannot find library: {name}\nList of
candidates:\n{lib_dll_path}")
+ raise RuntimeError(
+ f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}"
+ )
return lib_found[0]
@@ -108,7 +110,9 @@ def find_include_path():
"""Find header files for C compilation."""
candidates = [
os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",
"include"),
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"
+ ),
]
for candidate in candidates:
if os.path.isdir(candidate):
@@ -130,12 +134,19 @@ def find_python_helper_include_path():
def find_dlpack_include_path():
"""Find dlpack header files for C compilation."""
- install_include_path =
os.path.join(os.path.dirname(os.path.realpath(__file__)), "include")
+ install_include_path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "include"
+ )
if os.path.isdir(os.path.join(install_include_path, "dlpack")):
return install_include_path
source_include_path = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "..", "..", "3rdparty",
"dlpack", "include"
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "3rdparty",
+ "dlpack",
+ "include",
)
if os.path.isdir(source_include_path):
return source_include_path
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 103956e..fbfb35d 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -200,7 +200,9 @@ class Module(core.Object):
b : Bool
True if the module is compilation exportable.
"""
- return (self.get_property_mask() &
ModulePropertyMask.COMPILATION_EXPORTABLE) != 0
+ return (
+ self.get_property_mask() &
ModulePropertyMask.COMPILATION_EXPORTABLE
+ ) != 0
def clear_imports(self):
"""Remove all imports of the module."""
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 2cd1ba1..f31dea3 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""FFI registry to register function and objects."""
+
import sys
from . import core
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 598afca..084dca8 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Stream context."""
+
from ctypes import c_void_p
from typing import Any, Optional, Union
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index 3b3197e..55ab41f 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -64,7 +64,9 @@ class FileLock:
)
msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1)
else: # Unix-like systems
- self._file_descriptor = os.open(self.lock_file_path,
os.O_WRONLY | os.O_CREAT)
+ self._file_descriptor = os.open(
+ self.lock_file_path, os.O_WRONLY | os.O_CREAT
+ )
fcntl.flock(self._file_descriptor, fcntl.LOCK_EX |
fcntl.LOCK_NB)
return True
except (IOError, BlockingIOError):
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index d5d9bf5..48df954 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Helper tool to add ASF header to files that cannot be handled by Rat."""
+
import argparse
import fnmatch
import os
@@ -187,7 +188,9 @@ def get_git_files():
if result.returncode == 0:
return [line.strip() for line in result.stdout.split("\n") if
line.strip()]
else:
- print("Error: Could not get git files. Make sure you're in a git
repository.")
+ print(
+ "Error: Could not get git files. Make sure you're in a git
repository."
+ )
print("Git command failed:", result.stderr.strip())
return None
except FileNotFoundError:
@@ -343,7 +346,9 @@ Examples:
)
parser.add_argument(
- "--check", action="store_true", help="Check mode: report errors
without modifying files"
+ "--check",
+ action="store_true",
+ help="Check mode: report errors without modifying files",
)
parser.add_argument(
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index b44d5f1..d666470 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Helper tool to check file types that are allowed to checkin."""
+
import os
import subprocess
import sys
@@ -180,7 +181,7 @@ def main():
cmd = ["git", "ls-files"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- assert proc.returncode == 0, f'{" ".join(cmd)} errored: {out}'
+ assert proc.returncode == 0, f"{' '.join(cmd)} errored: {out}"
res = out.decode("utf-8")
flist = res.split()
error_list = []
@@ -211,11 +212,14 @@ def main():
if asf_copyright_list:
report = "------File type check report----\n"
report += "\n".join(asf_copyright_list) + "\n"
- report += "------Found %d files that has ASF header with copyright
message----\n" % len(
- asf_copyright_list
+ report += (
+ "------Found %d files that has ASF header with copyright
message----\n"
+ % len(asf_copyright_list)
)
report += "--- Files with ASF header do not need Copyright lines.\n"
- report += "--- Contributors retain copyright to their contribution by
default.\n"
+ report += (
+ "--- Contributors retain copyright to their contribution by
default.\n"
+ )
report += "--- If a file comes with a different license, consider put
it under the 3rdparty folder instead.\n"
report += "---\n"
report += "--- You can use the following steps to remove the copyright
lines\n"
diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh
index 70b3c5b..fee4803 100755
--- a/tests/lint/git-clang-format.sh
+++ b/tests/lint/git-clang-format.sh
@@ -19,77 +19,74 @@ set -e
set -u
set -o pipefail
-
INPLACE_FORMAT=${INPLACE_FORMAT:=false}
LINT_ALL_FILES=true
REVISION=$(git rev-list --max-parents=0 HEAD)
-while (( $# )); do
- case "$1" in
- -i)
- INPLACE_FORMAT=true
- shift 1
- ;;
- --rev)
- LINT_ALL_FILES=false
- REVISION=$2
- shift 2
- ;;
- *)
- echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev <commit>]"
- echo ""
- echo "Run clang-format on files that changed since <commit> or on
all files in the repo"
- echo "Examples:"
- echo "- Compare last one commit: tests/lint/git-clang-format.sh
--rev HEAD~1"
- echo "- Compare against upstream/main:
tests/lint/git-clang-format.sh --rev upstream/main"
- echo "The -i will use black to format files in-place instead of
checking them."
- exit 1
- ;;
- esac
+while (($#)); do
+ case "$1" in
+ -i)
+ INPLACE_FORMAT=true
+ shift 1
+ ;;
+ --rev)
+ LINT_ALL_FILES=false
+ REVISION=$2
+ shift 2
+ ;;
+ *)
+ echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev
<commit>]"
+ echo ""
+ echo "Run clang-format on files that changed since <commit> or
on all files in the repo"
+ echo "Examples:"
+ echo "- Compare last one commit: tests/lint/git-clang-format.sh
--rev HEAD~1"
+ echo "- Compare against upstream/main:
tests/lint/git-clang-format.sh --rev upstream/main"
+ echo "The -i will use black to format files in-place instead of
checking them."
+ exit 1
+ ;;
+ esac
done
-
-cleanup()
-{
- if [ -f /tmp/$$.clang-format.txt ]; then
- echo ""
- echo "---------clang-format log----------"
- cat /tmp/$$.clang-format.txt
- fi
- rm -rf /tmp/$$.clang-format.txt
+cleanup() {
+ if [ -f /tmp/$$.clang-format.txt ]; then
+ echo ""
+ echo "---------clang-format log----------"
+ cat /tmp/$$.clang-format.txt
+ fi
+ rm -rf /tmp/$$.clang-format.txt
}
trap cleanup 0
CLANG_FORMAT=clang-format-15
if [ -x "$(command -v clang-format-15)" ]; then
- CLANG_FORMAT=clang-format-15
+ CLANG_FORMAT=clang-format-15
elif [ -x "$(command -v clang-format)" ]; then
- echo "clang-format might be different from clang-format-15, expect
potential difference."
- CLANG_FORMAT=clang-format
+ echo "clang-format might be different from clang-format-15, expect
potential difference."
+ CLANG_FORMAT=clang-format
else
- echo "Cannot find clang-format-15"
- exit 1
+ echo "Cannot find clang-format-15"
+ exit 1
fi
# Print out specific version
${CLANG_FORMAT} --version
if [[ "$INPLACE_FORMAT" == "true" ]]; then
- echo "Running inplace git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT}
"$REVISION"
- exit 0
+ echo "Running inplace git-clang-format against $REVISION"
+ git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT}
"$REVISION"
+ exit 0
fi
if [[ "$LINT_ALL_FILES" == "true" ]]; then
- echo "Running git-clang-format against all C++ files"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu
--binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt
+ echo "Running git-clang-format against all C++ files"
+ git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu
--binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
else
- echo "Running git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu
--binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt
+ echo "Running git-clang-format against $REVISION"
+ git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu
--binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
fi
-if grep --quiet -E "diff" < /tmp/$$.clang-format.txt; then
- echo "clang-format lint error found. Consider running clang-format-15 on
these files to fix them."
- exit 1
+if grep --quiet -E "diff" </tmp/$$.clang-format.txt; then
+ echo "clang-format lint error found. Consider running clang-format-15
on these files to fix them."
+ exit 1
fi
diff --git a/tests/python/test_access_path.py b/tests/python/test_access_path.py
index d3f59fb..f70266b 100644
--- a/tests/python/test_access_path.py
+++ b/tests/python/test_access_path.py
@@ -94,16 +94,25 @@ def test_path_is_prefix_of():
assert not
AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo"))
# Shorter path is prefix of longer path with same start
- assert
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2))
+ assert (
+ AccessPath.root()
+ .attr("foo")
+ .is_prefix_of(AccessPath.root().attr("foo").array_item(2))
+ )
# Longer path is not prefix of shorter path
assert (
- not
AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo"))
+ not AccessPath.root()
+ .attr("foo")
+ .array_item(2)
+ .is_prefix_of(AccessPath.root().attr("foo"))
)
# Different paths are not prefixes
assert (
- not
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2))
+ not AccessPath.root()
+ .attr("foo")
+ .is_prefix_of(AccessPath.root().attr("bar").array_item(2))
)
@@ -124,10 +133,16 @@ def test_path_equal():
assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo"))
# Shorter path does not equal longer path
- assert not (AccessPath.root().attr("foo") ==
AccessPath.root().attr("foo").array_item(2))
+ assert not (
+ AccessPath.root().attr("foo") ==
AccessPath.root().attr("foo").array_item(2)
+ )
# Longer path does not equal shorter path
- assert not (AccessPath.root().attr("foo").array_item(2) ==
AccessPath.root().attr("foo"))
+ assert not (
+ AccessPath.root().attr("foo").array_item(2) ==
AccessPath.root().attr("foo")
+ )
# Different paths are not equal
- assert not (AccessPath.root().attr("foo") ==
AccessPath.root().attr("bar").array_item(2))
+ assert not (
+ AccessPath.root().attr("foo") ==
AccessPath.root().attr("bar").array_item(2)
+ )
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 0ac4043..41a7985 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -71,7 +71,9 @@ def test_device_with_dev_id(dev_type, dev_id,
expected_device_type, expect_devic
assert dev.index == expect_device_id
[email protected]("dev_type, dev_id", [("cpu:0:0", None), ("cpu:?",
None), ("cpu:", None)])
[email protected](
+ "dev_type, dev_id", [("cpu:0:0", None), ("cpu:?", None), ("cpu:", None)]
+)
def test_deive_type_error(dev_type, dev_id):
with pytest.raises(ValueError):
tvm_ffi.device(dev_type, dev_id)
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index bab7c3e..0d40d17 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -38,7 +38,9 @@ def test_method():
def test_setter():
# test setter
- obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10,
v_str="hello")
+ obj0 = tvm_ffi.testing.create_object(
+ "testing.TestObjectBase", v_i64=10, v_str="hello"
+ )
assert obj0.v_i64 == 10
obj0.v_i64 = 11
assert obj0.v_i64 == 11
diff --git a/tests/scripts/benchmark_dlpack.py
b/tests/scripts/benchmark_dlpack.py
index 954598c..96f43d3 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -36,6 +36,7 @@ Summary of some takeaways:
-
"""
+
import time
import numpy as np
@@ -182,27 +183,6 @@ def tvm_ffi_self_dlpack_nop(repeat):
bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(tvm)", x, y, z, repeat)
-def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
- """run dlpack conversion + tvm_ffi.nop
-
- Measures overhead of running dlpack for each args then invoke
- """
- nop = tvm_ffi.get_global_func("testing.nop")
- tx = tvm_ffi.from_dlpack(x)
- ty = tvm_ffi.from_dlpack(y)
- tz = tvm_ffi.from_dlpack(z)
- nop(tx, ty, tz)
-
- start = time.time()
- for i in range(repeat):
- tx = tvm_ffi.from_dlpack(x)
- ty = tvm_ffi.from_dlpack(y)
- tz = tvm_ffi.from_dlpack(z)
- nop(tx, ty, tz)
- end = time.time()
- print_speed(name, (end - start) / repeat)
-
-
def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
"""
Measures overhead of running dlpack for each args then invoke
@@ -238,7 +218,6 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
"""
nop = tvm_ffi.get_global_func("testing.nop")
nop(x, y, z)
- eps = 1e-6
start = time.time()
for i in range(repeat):
nop(x, y, z)
@@ -262,7 +241,9 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu",
stream=False):
f"tvm_ffi.nop.autodlpack(torch[{device}][stream])", x, y, z,
repeat
)
else:
-
bench_tvm_ffi_nop_autodlpack(f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y,
z, repeat)
+ bench_tvm_ffi_nop_autodlpack(
+ f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat
+ )
def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -367,7 +348,7 @@ def bench_torch_get_current_stream(repeat, name, func):
"""
Measures overhead of running torch.cuda.current_stream
"""
- x = torch.arange(1, device="cuda")
+ x = torch.arange(1, device="cuda") # noqa: F841
func(0)
start = time.time()
for i in range(repeat):
@@ -379,7 +360,9 @@ def bench_torch_get_current_stream(repeat, name, func):
def populate_object_table(num_classes):
nop = tvm_ffi.get_global_func("testing.nop")
- dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in
range(num_classes)]
+ dummy_instances = [
+ type(f"DummyClass{i}", (object,), {})() for i in range(num_classes)
+ ]
for instance in dummy_instances:
nop(instance)
@@ -418,15 +401,23 @@ def main():
print("---------------------------------------------------")
print("Benchmark x.__dlpack__(max_version=(1,1)) overhead")
print("---------------------------------------------------")
- bench_to_dlpack_versioned(torch.arange(1),
"torch.__dlpack__(max_version=(1,1))", repeat)
- bench_to_dlpack_versioned(np.arange(1),
"numpy.__dlpack__(max_version=(1,1))", repeat)
bench_to_dlpack_versioned(
- tvm_ffi.from_dlpack(torch.arange(1)),
"tvm.__dlpack__(max_version=(1,1))", repeat
+ torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat
+ )
+ bench_to_dlpack_versioned(
+ np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat
+ )
+ bench_to_dlpack_versioned(
+ tvm_ffi.from_dlpack(torch.arange(1)),
+ "tvm.__dlpack__(max_version=(1,1))",
+ repeat,
)
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[default stream]")
print("---------------------------------------------------")
- bench_torch_get_current_stream(repeat, "cpp-extension",
load_torch_get_current_cuda_stream())
+ bench_torch_get_current_stream(
+ repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+ )
bench_torch_get_current_stream(repeat, "python",
torch_get_cuda_stream_native)
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[non-default stream]")
diff --git a/tests/scripts/task_cpp_tests.sh b/tests/scripts/task_cpp_tests.sh
index 27795cc..d7e935d 100755
--- a/tests/scripts/task_cpp_tests.sh
+++ b/tests/scripts/task_cpp_tests.sh
@@ -22,6 +22,6 @@ BUILD_TYPE=RelWithDebugInfo
rm -rf build/CMakeCache.txt
cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
- -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
+ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
cmake --build build --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests
GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure
diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh
index 0d8aa4f..5b17cf8 100755
--- a/tests/scripts/task_lint.sh
+++ b/tests/scripts/task_lint.sh
@@ -18,30 +18,29 @@
set -euxo pipefail
-cleanup()
-{
- rm -rf /tmp/$$.*
+cleanup() {
+ rm -rf /tmp/$$.*
}
trap cleanup 0
function run_lint {
- echo "Checking file types..."
- python tests/lint/check_file_type.py
+ echo "Checking file types..."
+ python tests/lint/check_file_type.py
- echo "Checking ASF headers..."
- python tests/lint/check_asf_header.py --check
+ echo "Checking ASF headers..."
+ python tests/lint/check_asf_header.py --check
- echo "isort check..."
- isort --check --diff .
+ echo "isort check..."
+ isort --check --diff .
- echo "black check..."
- black --check --diff .
+ echo "black check..."
+ black --check --diff .
- echo "ruff check..."
- ruff check --diff .
+ echo "ruff check..."
+ ruff check --diff .
- echo "clang-format check..."
- tests/lint/git-clang-format.sh
+ echo "clang-format check..."
+ tests/lint/git-clang-format.sh
}
run_lint