This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d030ce27a1 [TVMScript] Optionally use `ruff format` instead of `black`
(#16876)
d030ce27a1 is described below
commit d030ce27a197e0a3e819b311dca5c5421d1cf5ba
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 17 00:04:10 2024 -0500
[TVMScript] Optionally use `ruff format` instead of `black` (#16876)
* [TVMScript] Optionally use `ruff format` instead of `black`
The `ruff format` tool is significantly faster than the `black`
formatter. For some particularly long TVMScript modules, using it can
reduce the time required to show a formatted module from ~5 minutes to
~1 minute. This commit updates the `.show()` function to optionally
apply formatting using `ruff format` if available, falling back
to `black` otherwise.
* Fix lint error
---
python/tvm/script/highlight.py | 95 ++++++++++++++++++++++++++++++++++--------
1 file changed, 77 insertions(+), 18 deletions(-)
diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py
index be0de5a6bf..e017c1e6ca 100644
--- a/python/tvm/script/highlight.py
+++ b/python/tvm/script/highlight.py
@@ -17,7 +17,10 @@
"""Highlight printed TVM script.
"""
+import functools
import os
+import shutil
+import subprocess
import sys
import warnings
from typing import Any, Optional, Union
@@ -92,7 +95,73 @@ def cprint(
print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style)))
-def _format(code_str: str) -> str:
+@functools.lru_cache
+def _get_formatter(formatter: Optional[str] = None):
+ def get_ruff_formatter():
+ if shutil.which("ruff") is None:
+ return None
+
+ def formatter(code_str):
+ proc = subprocess.Popen(
+ ["ruff", "format", "--stdin-filename=TVMScript"],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ )
+ stdout, _stderr = proc.communicate(code_str)
+ return stdout
+
+ return formatter
+
+ def get_black_formatter():
+ try:
+ # pylint: disable=import-outside-toplevel
+ import black
+ except ImportError:
+ return None
+
+ def formatter(code_str):
+ return black.format_str(code_str, mode=black.FileMode())
+
+ return formatter
+
+ def get_fallback_formatter():
+ def formatter(code_str):
+ with warnings.catch_warnings():
+ warnings.simplefilter("once", UserWarning)
+ ruff_install_cmd = sys.executable + " -m pip install ruff"
+ black_install_cmd = (
+ sys.executable + ' -m pip install "black==22.3.0" --upgrade --user'
+ )
+ warnings.warn(
+ f"Neither the 'ruff' formatter nor the 'black' formatter is available. "
+ f"To print formatted TVM script, please a formatter. \n"
+ f"To install ruff: {ruff_install_cmd}\n"
+ f"To install black: {black_install_cmd}",
+ category=UserWarning,
+ )
+ return code_str
+
+ return formatter
+
+ # formatter = "black"
+ if formatter is None:
+ options = [get_ruff_formatter, get_black_formatter]
+ elif formatter == "ruff":
+ options = [get_ruff_formatter]
+ elif formatter == "black":
+ options = [get_black_formatter]
+ else:
+ raise ValueError(f"Unknown formatter: {formatter}")
+
+ for option in options:
+ func = option()
+ if func is not None:
+ return func
+ return get_fallback_formatter()
+
+
+def _format(code_str: str, formatter: Optional[str] = None) -> str:
"""Format a code string using Black.
Parameters
@@ -101,29 +170,19 @@ def _format(code_str: str) -> str:
The string containing Python/TVMScript code to format
+ formatter: Optional[str]
+
+ The formatter to use. Can specify `ruff`, `black`, or
+ auto-select by passing `None`.
+
Returns
-------
formatted: str
The formatted Python/TVMScript code
+
"""
- try:
- # pylint: disable=import-outside-toplevel
- import black
- except ImportError as err:
- with warnings.catch_warnings():
- warnings.simplefilter("once", UserWarning)
- install_cmd = sys.executable + ' -m pip install "black==22.3.0" --upgrade --user'
- warnings.warn(
- str(err)
- + "\n"
- + "To print formatted TVM script, please install the formatter 'Black':\n"
- + install_cmd,
- category=UserWarning,
- )
- return code_str
- else:
- return black.format_str(code_str, mode=black.FileMode())
+ return _get_formatter(formatter)(code_str)
def _get_pygments_style(