This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff8e41644f [TVMScript] Avoid segfault from invalid TVMScript (#17373)
ff8e41644f is described below

commit ff8e41644fde86714d6dbf021d57baebe3a1ec1a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 17 09:07:41 2024 -0500

    [TVMScript] Avoid segfault from invalid TVMScript (#17373)
    
    * [TVMScript] Avoid segfault from invalid TVMScript
    
    Prior to this commit, after the `DiagnosticContext` prints its error,
    it overwrites the `DiagnosticRenderer` with a NULL renderer.  If a
    second call to `DiagnosticContext::Render` occurs, it will segfault.
    This appears to be intended to prevent double-printing of error
    messages, but a segfault is much worse than double-printing error
    messages.
    
    In addition, `DiagnosticContext::Render` should only be called once.
    There's a common pattern in the parser where it will wrap exceptions
    in `DiagnosticError`, but re-raise exceptions that are already a
    `DiagnosticError`.  This requires every such location to include
    `except DiagnosticError: raise`, and can easily be missed.
    
    This PR makes two changes: First, the `DiagnosticRenderer` is updated
    to have a no-op callback rather than a NULL callback.  Second, the
    re-raising of `DiagnosticError` is moved to `Parser.report_error`, so
    that it does not need to be handled separately at several independent
    locations in the TVMScript parser.
---
 python/tvm/script/parser/core/evaluator.py            | 12 ++++++------
 python/tvm/script/parser/core/parser.py               | 19 ++++++++++---------
 python/tvm/script/parser/relax/parser.py              | 10 +++++-----
 src/ir/diagnostic.cc                                  |  3 ++-
 tests/python/relax/test_tvmscript_parser.py           | 14 +++++++++++---
 .../tvmscript/test_tvmscript_printer_highlight.py     |  8 +++++---
 6 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 26e9d091bf..7a194c779d 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -267,8 +267,8 @@ class ExprEvaluator:
                 value = self._eval_slice(fields)
             else:
                 value = self._eval_expr(node.__class__(**fields))
-        except Exception as e:  # pylint: disable=broad-except,invalid-name
-            self.parser.report_error(node, e)
+        except Exception as err:  # pylint: disable=broad-except
+            self.parser.report_error(node, err)
         return self._add_intermediate_result(value)
 
     def _eval_lambda(self, node: doc.Lambda) -> Any:
@@ -286,8 +286,8 @@ class ExprEvaluator:
         """
         try:
             value = self._eval_expr(node)
-        except Exception as e:  # pylint: disable=broad-except,invalid-name
-            self.parser.report_error(node, str(e))
+        except Exception as err:  # pylint: disable=broad-except
+            self.parser.report_error(node, err)
         return self._add_intermediate_result(value)
 
     def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
@@ -463,8 +463,8 @@ def eval_assign(
     """
     try:
         return _eval_assign(target, source)
-    except Exception as e:  # pylint: disable=broad-except,invalid-name
-        parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+    except Exception as err:  # pylint: disable=broad-except
+        parser.report_error(target, err)
         raise
 
 
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index 0ecf669566..372a3c54e4 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -307,10 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
     def _wrapper(self: "Parser", node: doc.AST) -> None:
         try:
             return func(self, node)
-        except DiagnosticError:
-            raise
-        except Exception as e:  # pylint: disable=broad-except,invalid-name
-            self.report_error(node, e)
+        except Exception as err:  # pylint: disable=broad-except
+            self.report_error(node, err)
             raise
 
     return _wrapper
@@ -547,6 +545,12 @@ class Parser(doc.NodeVisitor):
         err: Union[Exception, str]
             The error to report.
         """
+
+        # If the error is already being raised as a DiagnosticError,
+        # re-raise it without wrapping it in a DiagnosticContext.
+        if isinstance(err, DiagnosticError):
+            raise err
+
         # Only take the last line of the error message
         if isinstance(err, TVMError):
             msg = list(filter(None, str(err).split("\n")))[-1]
@@ -595,11 +599,8 @@ class Parser(doc.NodeVisitor):
             raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
         try:
             func(node)
-        except DiagnosticError:
-            raise
-        except Exception as e:  # pylint: disable=broad-except,invalid-name
-            self.report_error(node, str(e))
-            raise
+        except Exception as err:  # pylint: disable=broad-except
+            self.report_error(node, err)
 
     def visit_body(self, node: List[doc.stmt]) -> Any:
         """The general body visiting method.
diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index 08269ddeeb..011136d5d3 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -104,9 +104,9 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
     try:
         annotation = self.eval_expr(node)
         return _normalize_struct_info_proxy(annotation)
-    except Exception as err:
-        self.report_error(node, str(err))
-        raise err
+    except Exception as err:  # pylint: disable=broad-except
+        self.report_error(node, err)
+        raise
 
 
 def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
@@ -114,9 +114,9 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
     try:
         struct_info = self.eval_expr(node)
         return _normalize_struct_info(struct_info, var_table)
-    except Exception as err:
+    except Exception as err:  # pylint: disable=broad-except
         self.report_error(node, err)
-        raise err
+        raise
 
 
 def is_called(node: Any, func_name: str) -> bool:
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index 9245ec9c0b..8eeb4b3e6f 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -127,7 +127,8 @@ void DiagnosticContext::Render() {
   }
 
   if (errs) {
-    (*this)->renderer = DiagnosticRenderer();
+    (*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {});
+    // (*this)->diagnostics.clear();
     LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
                << "emitted, please check diagnostic render for output.";
   }
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 64f2efd4af..fd465f3201 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -179,6 +179,15 @@ def test_unassigned_call_fail():
             return x
 
 
+def test_incorrect_tensor_shape():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: R.Tensor([16])):
+            y: R.Tensor(16) = R.add(x, x)
+            return y
+
+
 def test_simple_module():
     @I.ir_module
     class TestModule:
@@ -1045,7 +1054,6 @@ def test_call_tir_inplace():
 
 
 def test_call_tir_inplace_with_tuple_var_raises_error():
-
     with pytest.raises(tvm.error.DiagnosticError):
 
         @tvm.script.ir_module
@@ -1838,7 +1846,7 @@ def test_class_normalize():
     _check(InputModule, OutputModule)
 
 
-def test_context_aware_parsing():
+def test_context_aware_parsing(monkeypatch):
     @tvm.script.ir_module
     class Module:
         @T.prim_func
@@ -1863,7 +1871,7 @@ def test_context_aware_parsing():
     def _break_env(self, *args):
         raise RuntimeError("Fail to pass context-aware parsing")
 
-    tvm.ir.GlobalVar.__call__ = _break_env
+    monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env)
 
     _check(Module)
 
diff --git a/tests/python/tvmscript/test_tvmscript_printer_highlight.py b/tests/python/tvmscript/test_tvmscript_printer_highlight.py
index 16e90c3563..4c33b435f0 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_highlight.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_highlight.py
@@ -21,7 +21,7 @@ import tvm
 import tvm.testing
 from tvm import relay
 from tvm.script import tir as T
-from tvm.script.highlight import cprint
+from tvm.script.highlight import cprint, _format
 
 
 def test_highlight_script():
@@ -58,12 +58,14 @@ def test_cprint():
     # Print nodes with `script` method, e.g. PrimExpr
     cprint(tvm.tir.Var("v", "int32") + 1)
 
-    # Cannot print non-Python-style codes if black installed
+    # Cannot print non-Python-style codes when using the black
+    # formatter.  This error comes from `_format`, used internally by
+    # `cprint`, and doesn't occur when using the `ruff` formatter.
     try:
         import black
 
         with pytest.raises(ValueError):
-            cprint("if (a == 1) { a +=1; }")
+            _format("if (a == 1) { a +=1; }", formatter="black")
     except ImportError:
         pass
 

Reply via email to