This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff8e41644f [TVMScript] Avoid segfault from invalid TVMScript (#17373)
ff8e41644f is described below
commit ff8e41644fde86714d6dbf021d57baebe3a1ec1a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 17 09:07:41 2024 -0500
[TVMScript] Avoid segfault from invalid TVMScript (#17373)
* [TVMScript] Avoid segfault from invalid TVMScript
Prior to this commit, after the `DiagnosticContext` prints its error,
it overwrites the `DiagnosticRenderer` with a NULL renderer. If a
second call to `DiagnosticContext::Render` occurs, it will segfault.
This appears to be intended to prevent double-printing of error
messages, but a segfault is much worse than a double-printed error
message.
In addition, `DiagnosticContext::Render` should only be called once.
However, a common pattern in the parser wraps exceptions in
`DiagnosticError` while re-raising exceptions that are already a
`DiagnosticError`. This requires every such location to include
`except DiagnosticError: raise`, which is easy to miss.
This PR makes two changes: First, the `DiagnosticRenderer` is updated
to have a no-op callback rather than a NULL callback. Second, the
re-raising of `DiagnosticError` is moved to `Parser.report_error`, so
that it does not need to be handled separately at several independent
locations in the TVMScript parser.
---
python/tvm/script/parser/core/evaluator.py | 12 ++++++------
python/tvm/script/parser/core/parser.py | 19 ++++++++++---------
python/tvm/script/parser/relax/parser.py | 10 +++++-----
src/ir/diagnostic.cc | 3 ++-
tests/python/relax/test_tvmscript_parser.py | 14 +++++++++++---
.../tvmscript/test_tvmscript_printer_highlight.py | 8 +++++---
6 files changed, 39 insertions(+), 27 deletions(-)
diff --git a/python/tvm/script/parser/core/evaluator.py
b/python/tvm/script/parser/core/evaluator.py
index 26e9d091bf..7a194c779d 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -267,8 +267,8 @@ class ExprEvaluator:
value = self._eval_slice(fields)
else:
value = self._eval_expr(node.__class__(**fields))
- except Exception as e: # pylint: disable=broad-except,invalid-name
- self.parser.report_error(node, e)
+ except Exception as err: # pylint: disable=broad-except
+ self.parser.report_error(node, err)
return self._add_intermediate_result(value)
def _eval_lambda(self, node: doc.Lambda) -> Any:
@@ -286,8 +286,8 @@ class ExprEvaluator:
"""
try:
value = self._eval_expr(node)
- except Exception as e: # pylint: disable=broad-except,invalid-name
- self.parser.report_error(node, str(e))
+ except Exception as err: # pylint: disable=broad-except
+ self.parser.report_error(node, err)
return self._add_intermediate_result(value)
def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
@@ -463,8 +463,8 @@ def eval_assign(
"""
try:
return _eval_assign(target, source)
- except Exception as e: # pylint: disable=broad-except,invalid-name
- parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+ except Exception as err: # pylint: disable=broad-except
+ parser.report_error(target, err)
raise
diff --git a/python/tvm/script/parser/core/parser.py
b/python/tvm/script/parser/core/parser.py
index 0ecf669566..372a3c54e4 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -307,10 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) ->
dispatch.ParseMethod:
def _wrapper(self: "Parser", node: doc.AST) -> None:
try:
return func(self, node)
- except DiagnosticError:
- raise
- except Exception as e: # pylint: disable=broad-except,invalid-name
- self.report_error(node, e)
+ except Exception as err: # pylint: disable=broad-except
+ self.report_error(node, err)
raise
return _wrapper
@@ -547,6 +545,12 @@ class Parser(doc.NodeVisitor):
err: Union[Exception, str]
The error to report.
"""
+
+ # If the error is already being raised as a DiagnosticError,
+ # re-raise it without wrapping it in a DiagnosticContext.
+ if isinstance(err, DiagnosticError):
+ raise err
+
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
@@ -595,11 +599,8 @@ class Parser(doc.NodeVisitor):
raise NotImplementedError(f"Visitor of AST node is not
implemented: {name}")
try:
func(node)
- except DiagnosticError:
- raise
- except Exception as e: # pylint: disable=broad-except,invalid-name
- self.report_error(node, str(e))
- raise
+ except Exception as err: # pylint: disable=broad-except
+ self.report_error(node, err)
def visit_body(self, node: List[doc.stmt]) -> Any:
"""The general body visiting method.
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/script/parser/relax/parser.py
index 08269ddeeb..011136d5d3 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -104,9 +104,9 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) ->
StructInfoProxy:
try:
annotation = self.eval_expr(node)
return _normalize_struct_info_proxy(annotation)
- except Exception as err:
- self.report_error(node, str(err))
- raise err
+ except Exception as err: # pylint: disable=broad-except
+ self.report_error(node, err)
+ raise
def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) ->
StructInfo:
@@ -114,9 +114,9 @@ def eval_struct_info(self: Parser, node: doc.expr,
eval_str: bool = False) -> St
try:
struct_info = self.eval_expr(node)
return _normalize_struct_info(struct_info, var_table)
- except Exception as err:
+ except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
- raise err
+ raise
def is_called(node: Any, func_name: str) -> bool:
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index 9245ec9c0b..8eeb4b3e6f 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -127,7 +127,8 @@ void DiagnosticContext::Render() {
}
if (errs) {
- (*this)->renderer = DiagnosticRenderer();
+ (*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {});
+ // (*this)->diagnostics.clear();
LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
<< "emitted, please check diagnostic render for output.";
}
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 64f2efd4af..fd465f3201 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -179,6 +179,15 @@ def test_unassigned_call_fail():
return x
+def test_incorrect_tensor_shape():
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @R.function
+ def f(x: R.Tensor([16])):
+ y: R.Tensor(16) = R.add(x, x)
+ return y
+
+
def test_simple_module():
@I.ir_module
class TestModule:
@@ -1045,7 +1054,6 @@ def test_call_tir_inplace():
def test_call_tir_inplace_with_tuple_var_raises_error():
-
with pytest.raises(tvm.error.DiagnosticError):
@tvm.script.ir_module
@@ -1838,7 +1846,7 @@ def test_class_normalize():
_check(InputModule, OutputModule)
-def test_context_aware_parsing():
+def test_context_aware_parsing(monkeypatch):
@tvm.script.ir_module
class Module:
@T.prim_func
@@ -1863,7 +1871,7 @@ def test_context_aware_parsing():
def _break_env(self, *args):
raise RuntimeError("Fail to pass context-aware parsing")
- tvm.ir.GlobalVar.__call__ = _break_env
+ monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env)
_check(Module)
diff --git a/tests/python/tvmscript/test_tvmscript_printer_highlight.py
b/tests/python/tvmscript/test_tvmscript_printer_highlight.py
index 16e90c3563..4c33b435f0 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_highlight.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_highlight.py
@@ -21,7 +21,7 @@ import tvm
import tvm.testing
from tvm import relay
from tvm.script import tir as T
-from tvm.script.highlight import cprint
+from tvm.script.highlight import cprint, _format
def test_highlight_script():
@@ -58,12 +58,14 @@ def test_cprint():
# Print nodes with `script` method, e.g. PrimExpr
cprint(tvm.tir.Var("v", "int32") + 1)
- # Cannot print non-Python-style codes if black installed
+ # Cannot print non-Python-style codes when using the black
+ # formatter. This error comes from `_format`, used internally by
+ # `cprint`, and doesn't occur when using the `ruff` formatter.
try:
import black
with pytest.raises(ValueError):
- cprint("if (a == 1) { a +=1; }")
+ _format("if (a == 1) { a +=1; }", formatter="black")
except ImportError:
pass