This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d030ce27a1 [TVMScript] Optionally use `ruff format` instead of `black` (#16876)
d030ce27a1 is described below

commit d030ce27a197e0a3e819b311dca5c5421d1cf5ba
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 17 00:04:10 2024 -0500

    [TVMScript] Optionally use `ruff format` instead of `black` (#16876)
    
    * [TVMScript] Optionally use `ruff format` instead of `black`
    
    The `ruff format` tool is significantly faster than the `black`
    formatter.  For some particularly long TVMScript modules, using it can
    reduce the time required to show a formatted module from ~5 minutes to
    ~1 minute.  This commit updates the `.show()` function to apply
    optional formatting using `ruff format` if available, falling back
    to `black` otherwise.
    
    * Fix lint error
---
 python/tvm/script/highlight.py | 95 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 77 insertions(+), 18 deletions(-)

diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py
index be0de5a6bf..e017c1e6ca 100644
--- a/python/tvm/script/highlight.py
+++ b/python/tvm/script/highlight.py
@@ -17,7 +17,10 @@
 """Highlight printed TVM script.
 """
 
+import functools
 import os
+import shutil
+import subprocess
 import sys
 import warnings
 from typing import Any, Optional, Union
@@ -92,7 +95,73 @@ def cprint(
         print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style)))
 
 
-def _format(code_str: str) -> str:
+@functools.lru_cache
+def _get_formatter(formatter: Optional[str] = None):
+    def get_ruff_formatter():
+        if shutil.which("ruff") is None:
+            return None
+
+        def formatter(code_str):
+            proc = subprocess.Popen(
+                ["ruff", "format", "--stdin-filename=TVMScript"],
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                encoding="utf-8",
+            )
+            stdout, _stderr = proc.communicate(code_str)
+            return stdout
+
+        return formatter
+
+    def get_black_formatter():
+        try:
+            # pylint: disable=import-outside-toplevel
+            import black
+        except ImportError:
+            return None
+
+        def formatter(code_str):
+            return black.format_str(code_str, mode=black.FileMode())
+
+        return formatter
+
+    def get_fallback_formatter():
+        def formatter(code_str):
+            with warnings.catch_warnings():
+                warnings.simplefilter("once", UserWarning)
+                ruff_install_cmd = sys.executable + " -m pip install ruff"
+                black_install_cmd = (
+                    sys.executable + ' -m pip install "black==22.3.0" --upgrade --user'
+                )
+                warnings.warn(
+                    f"Neither the 'ruff' formatter nor the 'black' formatter is available.  "
+                    f"To print formatted TVM script, please a formatter.  \n"
+                    f"To install ruff: {ruff_install_cmd}\n"
+                    f"To install black: {black_install_cmd}",
+                    category=UserWarning,
+                )
+            return code_str
+
+        return formatter
+
+    # formatter = "black"
+    if formatter is None:
+        options = [get_ruff_formatter, get_black_formatter]
+    elif formatter == "ruff":
+        options = [get_ruff_formatter]
+    elif formatter == "black":
+        options = [get_black_formatter]
+    else:
+        raise ValueError(f"Unknown formatter: {formatter}")
+
+    for option in options:
+        func = option()
+        if func is not None:
+            return func
+    return get_fallback_formatter()
+
+
+def _format(code_str: str, formatter: Optional[str] = None) -> str:
     """Format a code string using Black.
 
     Parameters
@@ -101,29 +170,19 @@ def _format(code_str: str) -> str:
 
         The string containing Python/TVMScript code to format
 
+    formatter: Optional[str]
+
+        The formatter to use.  Can specify `ruff`, `black`, or
+        auto-select by passing `None`.
+
     Returns
     -------
     formatted: str
 
         The formatted Python/TVMScript code
+
     """
-    try:
-        # pylint: disable=import-outside-toplevel
-        import black
-    except ImportError as err:
-        with warnings.catch_warnings():
-            warnings.simplefilter("once", UserWarning)
-            install_cmd = sys.executable + ' -m pip install "black==22.3.0" --upgrade --user'
-            warnings.warn(
-                str(err)
-                + "\n"
-                + "To print formatted TVM script, please install the formatter 'Black':\n"
-                + install_cmd,
-                category=UserWarning,
-            )
-        return code_str
-    else:
-        return black.format_str(code_str, mode=black.FileMode())
+    return _get_formatter(formatter)(code_str)
 
 
 def _get_pygments_style(

Reply via email to