This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d5a4f66fdc [FFI] Propagate Python errors across FFI boundaries (#15596)
d5a4f66fdc is described below

commit d5a4f66fdc7008805c50550d6cfbfac79b9e8902
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 7 22:26:22 2023 -0700

    [FFI] Propagate Python errors across FFI boundaries (#15596)
    
    * [Runtime] Re-organize BacktraceFullCallback
    
    Prior to this commit, the `BacktraceFullCallback` function returned
    zero for frames that should be excluded from the stack trace, even if
    they were the `"TVMFuncCall"` that should terminate the stack trace.
    
    This commit re-organizes `BacktraceFullCallback`, moving the
    terminating checks to occur prior to the suppression checks, and
    adding comments to indicate why each suppression is present.
    
    * [FFI] Propagate Python errors across FFI boundaries
    
    Prior to this commit, if a Python script passes a callback to a C++
    function, and that callback raises an exception, the original
    exception cannot be caught in the outer python script.  As a result,
    interactive debugging, such as done with `pdb` or `ipdb`, could only
    inspect stack frames outside the first Python to C++ FFI call.
    
    This commit updates the FFI API to propagate the Python exception
    through an FFI boundary.  As a result, all Python frames in the stack
    trace can be inspected.
    
    * Updated unit tests that depended on exception coercion.
    
    Previously, Python exceptions were coerced to `tvm.error.TVMError` if
    no corresponding error type had been registered with
    `tvm._ffi.register_error`.  With the pass-through of Python
    exceptions, this coercion no longer applies.  Unit tests that relied
    on this coercion needed to be updated as a result.
    
    
    ---------
    
    Co-authored-by: Chris Sullivan <[email protected]>
---
 include/tvm/runtime/c_runtime_api.h                |   7 +
 include/tvm/runtime/registry.h                     |  45 +++++++
 python/tvm/_ffi/_ctypes/packed_func.py             |  26 ++--
 python/tvm/_ffi/_cython/base.pxi                   |   5 +-
 python/tvm/_ffi/_cython/packed_func.pxi            |  19 ++-
 python/tvm/_ffi/base.py                            | 149 ++++++++++++++++++++-
 src/ir/transform.cc                                |  88 ++++++++----
 src/relay/analysis/type_solver.cc                  |   2 -
 src/runtime/c_runtime_api.cc                       |  94 ++++++++++++-
 src/runtime/logging.cc                             | 137 +++++++++++++------
 src/runtime/registry.cc                            |  62 ++++++++-
 src/support/ffi_testing.cc                         |  12 ++
 tests/python/relay/test_pass_instrument.py         |  16 +--
 tests/python/relay/test_type_infer.py              |   2 +-
 ...eta_schedule_schedule_rule_apply_custom_rule.py |   2 +-
 tests/python/unittest/test_runtime_error.py        | 102 ++++++++++++--
 16 files changed, 653 insertions(+), 115 deletions(-)

diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 36ae5c6b15..43cf499481 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -244,6 +244,13 @@ typedef void* TVMObjectHandle;
   */
  TVM_DLL void TVMAPISetLastError(const char* msg);
 
+/*!
+ * \brief Used for implementing C API function.
+ *  Set the last error to a Python exception object before return.
+ * \param py_object The python exception to be set, as a PyObject*.
+ */
+TVM_DLL void TVMAPISetLastPythonError(void* py_object);
+
 /*!
  * \brief return str message of the last error
  *  all function in this file will return 0 when success
diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
index 3a1e86e87f..71ea9f4a34 100644
--- a/include/tvm/runtime/registry.h
+++ b/include/tvm/runtime/registry.h
@@ -97,6 +97,53 @@ namespace runtime {
  */
 TVM_DLL void EnvCheckSignals();
 
+/*! \brief A class that wraps a Python object and preserves its ownership.
+ *
+ * This class is used to wrap a PyObject* from the Python API and preserve its ownership.
+ * Allows for the creation of strong references to Python objects, which prevent them from being
+ * garbage-collected as long as the wrapper object exists.
+ */
+class WrappedPythonObject {
+ public:
+  /*! \brief Construct a wrapper that doesn't own anything */
+  WrappedPythonObject() : python_obj_(nullptr) {}
+
+  /*! \brief Conversion constructor from nullptr */
+  explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {}
+
+  /*! \brief Take ownership of a python object
+   *
+   * A new strong reference is created for the underlying python
+   * object.
+   *
+   * \param python_obj A PyObject* from the Python.h API.  A new
+   * strong reference is created using Py_IncRef.
+   */
+  explicit WrappedPythonObject(void* python_obj);
+
+  /*! \brief Drop ownership of a python object
+   *
+   * Removes the strong reference held by the wrapper.
+   */
+  ~WrappedPythonObject();
+
+  WrappedPythonObject(WrappedPythonObject&&);
+  WrappedPythonObject& operator=(WrappedPythonObject&&);
+
+  WrappedPythonObject(const WrappedPythonObject&);
+  WrappedPythonObject& operator=(const WrappedPythonObject&);
+  WrappedPythonObject& operator=(std::nullptr_t);
+
+  /*! \brief Whether the wrapper currently holds a python object */
+  explicit operator bool() const { return python_obj_ != nullptr; }
+
+  /*! \brief The wrapped PyObject*, or nullptr.  No new reference is created. */
+  void* raw_pointer() const { return python_obj_; }
+
+ private:
+  void* python_obj_ = nullptr;
+};
+
 /*! \brief Registry for global function */
 class Registry {
  public:
diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py
index 32ffe3d8c6..e8680afcdf 100644
--- a/python/tvm/_ffi/_ctypes/packed_func.py
+++ b/python/tvm/_ffi/_ctypes/packed_func.py
@@ -22,7 +22,7 @@ import ctypes
 import traceback
 from numbers import Number, Integral
 
-from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
+from ..base import _LIB, get_last_ffi_error, py2cerror, check_call, raise_last_ffi_error
 from ..base import c_str, string_types
 from ..runtime_ctypes import DataType, TVMByteArray, Device, ObjectRValueRef
 from . import ndarray as _nd
@@ -80,10 +80,11 @@ def convert_to_tvm_func(pyfunc):
         # pylint: disable=broad-except
         try:
             rv = local_pyfunc(*pyargs)
-        except Exception:
+        except Exception as err:
             msg = traceback.format_exc()
             msg = py2cerror(msg)
-            _LIB.TVMAPISetLastError(c_str(msg))
+            _LIB.TVMAPISetLastPythonError(ctypes.py_object(err))
+
             return -1
 
         if rv is not None:
@@ -94,7 +95,7 @@ def convert_to_tvm_func(pyfunc):
             if not isinstance(ret, TVMRetValueHandle):
                 ret = TVMRetValueHandle(ret)
             if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
-                raise get_last_ffi_error()
+                raise_last_ffi_error()
             _ = temp_args
             _ = rv
         return 0
@@ -106,7 +107,7 @@ def convert_to_tvm_func(pyfunc):
     pyobj = ctypes.py_object(f)
     ctypes.pythonapi.Py_IncRef(pyobj)
     if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     return _make_packed_func(handle, False)
 
 
@@ -212,7 +213,7 @@ class PackedFuncBase(object):
     def __del__(self):
         if not self.is_global and _LIB is not None:
             if _LIB.TVMFuncFree(self.handle) != 0:
-                raise get_last_ffi_error()
+                raise_last_ffi_error()
 
     def __call__(self, *args):
         """Call the function with positional arguments
@@ -235,7 +236,7 @@ class PackedFuncBase(object):
             )
             != 0
         ):
-            raise get_last_ffi_error()
+            raise_last_ffi_error()
         _ = temp_args
         _ = args
         return RETURN_SWITCH[ret_tcode.value](ret_val)
@@ -258,7 +259,7 @@ def __init_handle_by_constructor__(fconstructor, args):
         )
         != 0
     ):
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     _ = temp_args
     _ = args
     assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
@@ -333,3 +334,13 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
     global _FUNC_CONVERT_TO_OBJECT
     _CLASS_OBJECT_GENERIC = object_generic_class
     _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
+
+
+def _init_pythonapi_inc_def_ref():
+    """Register Py_IncRef/Py_DecRef so that C++ can hold references to Python objects"""
+    register_func = _LIB.TVMBackendRegisterEnvCAPI
+    register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef)
+    register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef)
+
+
+_init_pythonapi_inc_def_ref()
diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi
index c2c0667497..69e1355f7d 100644
--- a/python/tvm/_ffi/_cython/base.pxi
+++ b/python/tvm/_ffi/_cython/base.pxi
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from ..base import get_last_ffi_error
+from ..base import raise_last_ffi_error
 from libcpp.vector cimport vector
 from cpython.version cimport PY_MAJOR_VERSION
 from cpython cimport pycapsule
@@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
 # We mark the possibly long running function as nogil below.
 cdef extern from "tvm/runtime/c_runtime_api.h":
     void TVMAPISetLastError(const char* msg)
+    void TVMAPISetLastPythonError(void* py_object) except +
     const char *TVMGetLastError()
     int TVMFuncGetGlobal(const char* name,
                          TVMPackedFuncHandle* out)
@@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2:
     if ret == -2:
         return -2
     if ret != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     return 0
 
 
diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi
index 7c9ef51bd6..ae528bcb78 100644
--- a/python/tvm/_ffi/_cython/packed_func.pxi
+++ b/python/tvm/_ffi/_cython/packed_func.pxi
@@ -54,10 +54,11 @@ cdef int tvm_callback(TVMValue* args,
             pyargs.append(c_make_array(value.v_handle, True, False))
     try:
         rv = local_pyfunc(*pyargs)
-    except Exception:
+    except Exception as err:
         msg = traceback.format_exc()
         msg = py2cerror(msg)
-        TVMAPISetLastError(c_str(msg))
+        TVMAPISetLastPythonError(<void*>err)
+
         return -1
     if rv is not None:
         if isinstance(rv, tuple):
@@ -368,3 +369,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
     global _FUNC_CONVERT_TO_OBJECT
     _CLASS_OBJECT_GENERIC = object_generic_class
     _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
+
+# Py_INCREF and Py_DECREF are C macros, not function objects.
+# Therefore, wrapper functions are provided so their addresses can be registered.
+cdef void _py_incref_wrapper(void* py_object):
+    Py_INCREF(<object>py_object)
+cdef void _py_decref_wrapper(void* py_object):
+    Py_DECREF(<object>py_object)
+
+def _init_pythonapi_inc_def_ref():
+    register_func = TVMBackendRegisterEnvCAPI
+    register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
+    register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
+
+_init_pythonapi_inc_def_ref()
diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py
index 744e4c93e1..f0eddf8b36 100644
--- a/python/tvm/_ffi/base.py
+++ b/python/tvm/_ffi/base.py
@@ -17,10 +17,17 @@
 # coding: utf-8
 # pylint: disable=invalid-name, import-outside-toplevel
 """Base library for TVM FFI."""
-import sys
-import os
 import ctypes
+import functools
+import os
+import re
+import sys
+import types
+
+from typing import Callable, Sequence
+
 import numpy as np
+
 from . import libinfo
 
 # ----------------------------
@@ -333,6 +340,146 @@ def get_last_ffi_error():
     return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
 
 
+def _append_traceback_frame(tb, func_name, filepath, lineno):
+    """Append a dummy frame to appear in the Python traceback"""
+
+    # Compile a dummy function to Python bytecode, using the
+    # filepath that we want to appear in the traceback.  Any external
+    # debugger (e.g. pdb) that catches the exception will use the
+    # filepath to show code snippets from that FFI file.
+    code = compile(
+        "{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)),
+        filepath,
+        "exec",
+    )
+
+    # Replacing the name by updating the bytecode allows the function
+    # name to be values that would normally be forbidden by python
+    # syntax.  For example, "operator()".
+    code = code.replace(co_consts=(code.co_consts[0].replace(co_name=func_name), func_name, None))
+    namespace = {}
+    exec(code, namespace)  # pylint: disable=exec-used
+    dummy_func = namespace["dummy_func"]
+
+    # Execute the dummy function in order to generate a stack frame.
+    dummy_tb = None
+    try:
+        dummy_func()
+    except NotImplementedError as err:
+        dummy_tb = err.__traceback__
+
+    # Insert the dummy function into the stack trace.
+    new_frame = dummy_tb.tb_next
+    return types.TracebackType(tb, new_frame.tb_frame, new_frame.tb_lasti, new_frame.tb_lineno)
+
+
+def _filter_traceback_frames(tb, filter_funcs: Sequence[Callable[[types.CodeType], bool]]):
+    """Return a traceback like `tb`, without frames matched by any of `filter_funcs`.
+
+    If every frame would be removed, the original traceback is returned unchanged.
+    """
+    orig = tb
+    kept_frames = []
+
+    while tb is not None:
+        frame_code = tb.tb_frame.f_code
+        should_remove = any(filter_func(frame_code) for filter_func in filter_funcs)
+        if not should_remove:
+            kept_frames.append(tb)
+        tb = tb.tb_next
+
+    if not kept_frames:
+        return orig
+
+    def _append_frame(tb_next, frame):
+        return types.TracebackType(tb_next, frame.tb_frame, frame.tb_lasti, frame.tb_lineno)
+
+    # Rebuild the chain from the innermost kept frame outwards.  Starting
+    # from `None` ensures that removed frames located after the last kept
+    # frame are not left reachable through its `tb_next`.
+    return functools.reduce(_append_frame, reversed(kept_frames), None)
+
+
+def raise_last_ffi_error():
+    """Raise the previous error from FFI
+
+    This should be used instead of `raise get_last_ffi_error()`, as it
+    handles propagation of errors across an FFI boundary.  For example,
+    if Python passes a callback to a C++ function, and the callback
+    raises an exception, the re-thrown exception should contain the
+    full stack trace, not just the stack frames that are above the
+    outermost FFI call.
+    """
+
+    _LIB.TVMGetLastPythonError.restype = ctypes.c_void_p
+    _LIB.TVMGetLastBacktrace.restype = ctypes.c_char_p
+    py_err = _LIB.TVMGetLastPythonError()
+    if py_err is None:
+        c_err_msg = py_str(_LIB.TVMGetLastError())
+        py_err_msg, err_type = c2pyerror(c_err_msg)
+        if err_type is not None and err_type.startswith("tvm.error."):
+            err_type = err_type[10:]
+        py_err = ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
+
+    else:
+        # TVMGetLastPythonError returns a PyObject*, with NULL when
+        # there is no such value.  If we annotated the restype as
+        # ctypes.py_object, we would need to return Py_None from the
+        # C++ implementation.  This would require introducing a
+        # dependency on libpython that we want to avoid when not in a
+        # Python environment.  Therefore, casting the resulting void*
+        # pointer to PyObject* using ctypes.
+        py_err = ctypes.cast(ctypes.c_void_p(py_err), ctypes.py_object).value
+
+    tb = py_err.__traceback__
+
+    # The py_err.__traceback__ only goes from the location thrown
+    # up to the next FFI handoff.  To have the stacktrace also
+    # include the C++ side, we need to adjust the __traceback__
+    # before re-throwing.
+    backtrace = _LIB.TVMGetLastBacktrace()
+    if backtrace:
+        frames = re.split(r"\n\W+\d+:\W+", py_str(backtrace))
+        frames = frames[1:]  # Skip "Stack trace: "
+
+        for frame in frames:
+            if " at " in frame:
+                func_name, frame = frame.split(" at ", 1)
+                filename, _, lineno = frame.rpartition(":")
+                func_name = func_name.strip()
+                filename = filename.strip()
+                try:
+                    lineno = int(lineno.strip())
+                except ValueError:
+                    # No parsable line number.  Skip this frame, rather
+                    # than masking the exception that is being re-raised.
+                    continue
+
+                tb = _append_traceback_frame(tb, func_name, filename, lineno)
+
+    # Remove stack frames that provide little benefit to
+    # debugging.  These are only removed from the stack frames
+    # contained within the exception we are re-raising, and not to
+    # the stack frames that it will continue to collect.
+    # Therefore, there may still be a single instance of these
+    # frames in the outermost Python-to-FFI call.
+    filter_funcs = [
+        lambda code: "tvm/_ffi/_ctypes/packed_func.py" in code.co_filename,
+        lambda code: "tvm/_ffi/base.py" in code.co_filename,
+    ]
+    tb = _filter_traceback_frames(tb, filter_funcs)
+
+    py_err = py_err.with_traceback(tb)
+
+    # The exception PyObject may contain a large amount of state,
+    # including all stack frames that may be inspected in a later
+    # PDB post-mortem.  Therefore, we must make sure to remove the
+    # underlying PyObject* from the C++ side after we retrieve it.
+    _LIB.TVMDropLastPythonError()
+
+    raise py_err
+
+
 def check_call(ret):
     """Check the return value of C API call
 
@@ -345,4 +492,4 @@ def check_call(ret):
         return value from API calls
     """
     if ret != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 66b06e6b50..9f98977790 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -182,44 +182,64 @@ Map<String, Map<String, String>> PassContext::ListConfigs() {
 
 PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
 
+namespace {
+/*! \brief On destruction, clears `*instruments` unless `instruments` was reset to nullptr */
+struct ClearOnError {
+  Array<instrument::PassInstrument>* instruments{nullptr};
+
+  ~ClearOnError() {
+    if (instruments) {
+      LOG(INFO) << "Pass instrumentation enter/exit failed.";
+      LOG(INFO) << "Disabling pass instrumentation.";
+      instruments->clear();
+    }
+  }
+};
+/*! \brief On destruction, calls ExitPassContext on each of `successes`, in reverse order */
+struct ExitContextOnError {
+  std::vector<instrument::PassInstrument> successes;
+
+  ~ExitContextOnError() {
+    for (auto it = successes.rbegin(); it != successes.rend(); it++) {
+      LOG(INFO) << (*it)->name << " exiting PassContext ...";
+      // Destructors are implicitly noexcept, so an exception escaping
+      // here would call std::terminate.  Report it and keep exiting the
+      // remaining instruments instead.
+      try {
+        (*it)->ExitPassContext();
+        LOG(INFO) << (*it)->name << " exited PassContext.";
+      } catch (const std::exception& e) {
+        LOG(WARNING) << (*it)->name << " failed to exit PassContext: " << e.what();
+      } catch (...) {
+        LOG(WARNING) << (*it)->name << " failed to exit PassContext.";
+      }
+    }
+  }
+};
+}  // namespace
+
 void PassContext::InstrumentEnterPassContext() {
   auto pass_ctx_node = this->operator->();
   if (pass_ctx_node->instruments.defined()) {
-    Array<instrument::PassInstrument> enter_successes;
-    try {
-      for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
-        pi->EnterPassContext();
-        enter_successes.push_back(pi);
-      }
-    } catch (const Error& e) {
-      LOG(INFO) << "Pass instrumentation entering pass context failed.";
-      LOG(INFO) << "Disable pass instrumentation.";
-      pass_ctx_node->instruments.clear();
-
-      for (instrument::PassInstrument pi : enter_successes) {
-        LOG(INFO) << pi->name << " exiting PassContext ...";
-        pi->ExitPassContext();
-        LOG(INFO) << pi->name << " exited PassContext.";
-      }
-      enter_successes.clear();
-
-      throw e;
+    ClearOnError clear_context{&pass_ctx_node->instruments};
+    ExitContextOnError exit_context;
+    for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+      pi->EnterPassContext();
+      exit_context.successes.push_back(pi);
     }
+    exit_context.successes.clear();
+    clear_context.instruments = nullptr;
   }
 }
 
 void PassContext::InstrumentExitPassContext() {
   auto pass_ctx_node = this->operator->();
   if (pass_ctx_node->instruments.defined()) {
-    try {
-      for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
-        pi->ExitPassContext();
-      }
-    } catch (const Error& e) {
-      LOG(INFO) << "Pass instrumentation exiting pass context failed.";
-      pass_ctx_node->instruments.clear();
-      throw e;
+    ClearOnError clear_context{&pass_ctx_node->instruments};
+    for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+      pi->ExitPassContext();
     }
+    clear_context.instruments = nullptr;
   }
 }
 
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 79b340390b..5bd5698d83 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -639,8 +639,6 @@ bool TypeSolver::Solve() {
     } catch (const CompileError& err) {
       this->Emit(Diagnostic::Error(rnode->span) << err.what());
       rnode->resolved = false;
-    } catch (const Error& e) {
-      ICHECK(false) << e.what();
     }
 
     // Mark inqueue as false after the function call
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 93ca8a924a..980447214a 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -35,6 +35,8 @@
 #include <cstdlib>
 #include <sstream>
 #include <string>
+#include <tuple>
+#include <variant>
 
 #include "object_internal.h"
 #include "runtime_base.h"
@@ -368,22 +370,104 @@ std::string NormalizeError(std::string err_msg) {
 
 using namespace tvm::runtime;
 
+struct WrappedPythonError : Error {
+  WrappedPythonError() : Error("") {}
+  explicit WrappedPythonError(WrappedPythonObject obj)
+      : Error(""), obj(std::move(obj)), cpp_backtrace(tvm::runtime::Backtrace()) {}
+
+  WrappedPythonObject obj;
+  std::string cpp_backtrace;
+};
+
 struct TVMRuntimeEntry {
   std::string ret_str;
-  std::string last_error;
   TVMByteArray ret_bytes;
+
+  // std::string is the first alternative so that a default-constructed
+  // entry holds an empty message, for which TVMGetLastError() returns "".
+  std::variant<std::string, WrappedPythonError, InternalError> last_error;
+  std::string last_error_formatted;
 };
 
 typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
 
-const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); }
+const char* TVMGetLastError() {
+  auto* store = TVMAPIRuntimeStore::Get();
+  const auto& last_error = store->last_error;
+  if (const auto* message = std::get_if<std::string>(&last_error)) {
+    return message->c_str();
+  } else if (const auto* internal = std::get_if<InternalError>(&last_error)) {
+    // Use last_error_formatted to store the formatted error message, to avoid
+    // dangling pointer.
+    store->last_error_formatted = NormalizeError(internal->full_message());
+    return store->last_error_formatted.c_str();
+  } else {
+    return nullptr;
+  }
+}
+
+extern "C" void* TVMGetLastPythonError() {
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  if (auto* wrapped = std::get_if<WrappedPythonError>(&last_error)) {
+    return wrapped->obj.raw_pointer();
+  } else {
+    return nullptr;
+  }
+}
+
+extern "C" const char* TVMGetLastBacktrace() {
+  const auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  if (const auto* wrapped = std::get_if<WrappedPythonError>(&last_error)) {
+    return wrapped->cpp_backtrace.data();
+  } else if (const auto* wrapped = std::get_if<InternalError>(&last_error)) {
+    return wrapped->backtrace().data();
+  } else {
+    return nullptr;
+  }
+}
+
+extern "C" void TVMDropLastPythonError() {
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  if (std::get_if<WrappedPythonError>(&last_error)) {
+    last_error = "";
+  }
+}
 
 int TVMAPIHandleException(const std::exception& e) {
-  TVMAPISetLastError(NormalizeError(e.what()).c_str());
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+
+  if (const auto* wrapped = dynamic_cast<const WrappedPythonError*>(&e)) {
+    last_error = *wrapped;
+  } else if (const auto* internal = dynamic_cast<const InternalError*>(&e)) {
+    last_error = *internal;
+  } else {
+    last_error = NormalizeError(e.what());
+  }
   return -1;
 }
 
-void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; }
+extern "C" void TVMAPISetLastPythonError(void* obj) {
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  last_error = WrappedPythonError(WrappedPythonObject(obj));
+}
+
+void ThrowLastError() {
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  if (auto* wrapped = std::get_if<WrappedPythonError>(&last_error)) {
+    WrappedPythonError wrapped_err;
+    std::swap(wrapped_err, *wrapped);
+    throw wrapped_err;
+  } else if (auto* internal = std::get_if<InternalError>(&last_error)) {
+    throw *internal;
+  } else if (auto* message = std::get_if<std::string>(&last_error)) {
+    throw tvm::Error(NormalizeError(*message) + tvm::runtime::Backtrace());
+  }
+}
+
+void TVMAPISetLastError(const char* msg) {
+  auto& last_error = TVMAPIRuntimeStore::Get()->last_error;
+  last_error = msg;
+}
 
 int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) {
   API_BEGIN();
@@ -515,7 +599,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked
       int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
                      args.num_args, rv, resource_handle);
       if (ret != 0) {
-        throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace());
+        ThrowLastError();
       }
     });
     TVMValue val;
@@ -531,7 +615,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked
       int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
                      args.num_args, rv, rpack.get());
       if (ret != 0) {
-        throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace());
+        ThrowLastError();
       }
     });
     TVMValue val;
diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc
index 04b25f764c..844a8bcf1c 100644
--- a/src/runtime/logging.cc
+++ b/src/runtime/logging.cc
@@ -94,64 +94,115 @@ void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uin
 
 int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno,
                           const char* symbol) {
-  if (filename != nullptr) {
-    if (strstr(filename, "include/tvm/runtime/packed_func.h") != nullptr ||
-        strstr(filename, "include/tvm/runtime/registry.h") != nullptr ||
-        strstr(filename, "include/tvm/node/functor.h") != nullptr ||
-        strstr(filename, "include/tvm/relax/expr_functor.h") != nullptr ||
-        strstr(filename, "include/tvm/tir/stmt_functor.h") != nullptr ||
-        strstr(filename, "include/tvm/tir/expr_functor.h") != nullptr ||
-        strstr(filename, "include/tvm/node/reflection.h") != nullptr ||
-        strstr(filename, "src/node/structural_equal.cc") != nullptr ||
-        strstr(filename, "src/ir/transform.cc") != nullptr ||
-        strstr(filename, "src/tir/ir/stmt_functor.cc") != nullptr ||
-        strstr(filename, "src/tir/ir/expr_functor.cc") != nullptr ||
-        strstr(filename, "src/relax/ir/expr_functor.cc") != nullptr ||
-        strstr(filename, "src/relax/ir/py_expr_functor.cc") != nullptr ||
-        strstr(filename, "src/runtime/c_runtime_api.cc") != nullptr ||
-        strstr(filename, "/python-") != nullptr ||  //
-        strstr(filename, "include/c++/") != nullptr) {
-      return 0;
-    }
-  }
-  if (symbol != nullptr) {
-    if (strstr(symbol, "__libc_") != nullptr) {
-      return 0;
-    }
-  }
   auto stack_trace = reinterpret_cast<BacktraceInfo*>(data);
-  std::stringstream s;
 
   std::unique_ptr<std::string> symbol_str = std::make_unique<std::string>("<unknown>");
-  if (symbol != nullptr) {
+  if (symbol) {
     *symbol_str = DemangleName(symbol);
   } else {
     // see if syminfo gives anything
     backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback,
                       symbol_str.get());
   }
-  if (filename == nullptr && strstr(symbol_str.get()->data(), "ffi_call_")) {
+  symbol = symbol_str->data();
+
+  // TVMFuncCall denotes the API boundary so we stop there. Exceptions
+  // should be caught there.  This is before any frame suppressions,
+  // as it would otherwise be suppressed.
+  bool should_stop_collecting =
+      (*symbol_str == "TVMFuncCall" || stack_trace->lines.size() >= stack_trace->max_size);
+  if (should_stop_collecting) {
+    return 1;
+  }
+
+  // Exclude frames that contain little useful information for most
+  // debugging purposes
+  bool should_exclude = [&]() -> bool {
+    if (filename) {
+      // Stack frames for TVM FFI
+      if (strstr(filename, "include/tvm/runtime/packed_func.h") ||
+          strstr(filename, "include/tvm/runtime/registry.h") ||
+          strstr(filename, "src/runtime/c_runtime_api.cc")) {
+        return true;
+      }
+      // Stack frames for nested tree recursion.
+      // tir/ir/stmt_functor.cc and tir/ir/expr_functor.cc define
+      // Expr/Stmt Visitor/Mutator, which should be suppressed, but
+      // also Substitute which should not be suppressed.  Therefore,
+      // they are suppressed based on the symbol name.
+      if (strstr(filename, "include/tvm/node/functor.h") ||        //
+          strstr(filename, "include/tvm/relax/expr_functor.h") ||  //
+          strstr(filename, "include/tvm/tir/stmt_functor.h") ||    //
+          strstr(filename, "include/tvm/tir/expr_functor.h") ||    //
+          strstr(filename, "include/tvm/node/reflection.h") ||     //
+          strstr(filename, "src/node/structural_equal.cc") ||      //
+          strstr(filename, "src/ir/transform.cc") ||               //
+          strstr(filename, "src/relax/ir/expr_functor.cc") ||      //
+          strstr(filename, "src/relax/ir/py_expr_functor.cc")) {
+        return true;
+      }
+      // Python interpreter stack frames
+      if (strstr(filename, "/python-") || strstr(filename, "/Python/ceval.c") ||
+          strstr(filename, "/Modules/_ctypes")) {
+        return true;
+      }
+      // C++ stdlib frames
+      if (strstr(filename, "include/c++/")) {
+        return true;
+      }
+    }
+    if (symbol) {
+      // C++ stdlib frames
+      if (strstr(symbol, "__libc_")) {
+        return true;
+      }
+      // Stack frames for nested tree visiting
+      if (strstr(symbol, "tvm::tir::StmtMutator::VisitStmt_") ||
+          strstr(symbol, "tvm::tir::ExprMutator::VisitExpr_") ||
+          strstr(symbol, "tvm::tir::IRTransformer::VisitExpr") ||
+          strstr(symbol, "tvm::tir::IRTransformer::VisitStmt") ||
+          strstr(symbol, "tvm::tir::IRTransformer::BaseVisitExpr") ||
+          strstr(symbol, "tvm::tir::IRTransformer::BaseVisitStmt")) {
+        return true;
+      }
+      // Python interpreter stack frames
+      if (strstr(symbol, "_Py") == symbol || strstr(symbol, "PyObject")) {
+        return true;
+      }
+    }
+
+    // libffi.so stack frames.  These may also show up as numeric
+    // addresses with no symbol name.  This could be improved in the
+    // future by using dladdr() to check whether an address is contained
+    // in libffi.so
+    if (filename == nullptr && strstr(symbol, "ffi_call_")) {
+      return true;
+    }
+
+    // Skip tvm::backtrace and tvm::LogFatal::~LogFatal at the beginning
+    // of the trace as they don't add anything useful to the backtrace.
+    if (stack_trace->lines.size() == 0 && (strstr(symbol, "tvm::runtime::Backtrace") ||
+                                           strstr(symbol, "tvm::runtime::detail::LogFatal"))) {
+      return true;
+    }
+
+    return false;
+  }();
+  if (should_exclude) {
     return 0;
   }
-  s << *symbol_str;
 
-  if (filename != nullptr) {
-    s << std::endl << "        at " << filename;
+  std::stringstream frame_str;
+  frame_str << *symbol_str;
+
+  if (filename) {
+    frame_str << std::endl << "        at " << filename;
     if (lineno != 0) {
-      s << ":" << lineno;
+      frame_str << ":" << lineno;
     }
   }
-  // Skip tvm::backtrace and tvm::LogFatal::~LogFatal at the beginning of the trace as they don't
-  // add anything useful to the backtrace.
-  if (!(stack_trace->lines.size() == 0 &&
-        (symbol_str->find("tvm::runtime::Backtrace", 0) == 0 ||
-         symbol_str->find("tvm::runtime::detail::LogFatal", 0) == 0))) {
-    stack_trace->lines.push_back(s.str());
-  }
-  // TVMFuncCall denotes the API boundary so we stop there. Exceptions should be caught there.
-  if (*symbol_str == "TVMFuncCall" || stack_trace->lines.size() >= stack_trace->max_size) {
-    return 1;
-  }
+  stack_trace->lines.push_back(frame_str.str());
+
   return 0;
 }
 
diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc
index 84586ff630..0db8786145 100644
--- a/src/runtime/registry.cc
+++ b/src/runtime/registry.cc
@@ -128,13 +128,26 @@ class EnvCAPIRegistry {
    */
   typedef int (*F_PyErr_CheckSignals)();
 
-  // NOTE: the following function are only registered
-  // in a python environment.
+  /*! \brief Callback to increment/decrement the python ref count */
+  typedef void (*F_Py_IncDefRef)(void*);
+
+  // NOTE: the following functions are only registered in a python
+  // environment.
   /*!
    * \brief PyErr_CheckSignal function
    */
   F_PyErr_CheckSignals pyerr_check_signals = nullptr;
 
+  /*!
+   * \brief Py_IncRef function, registered under the name "Py_IncRef"
+   */
+  F_Py_IncDefRef py_inc_ref = nullptr;
+
+  /*!
+   * \brief Py_DecRef function, registered under the name "Py_DecRef"
+   */
+  F_Py_IncDefRef py_dec_ref = nullptr;
+
   static EnvCAPIRegistry* Global() {
     static EnvCAPIRegistry* inst = new EnvCAPIRegistry();
     return inst;
@@ -144,6 +157,10 @@ class EnvCAPIRegistry {
   void Register(const String& symbol_name, void* fptr) {
     if (symbol_name == "PyErr_CheckSignals") {
       Update(symbol_name, &pyerr_check_signals, fptr);
+    } else if (symbol_name == "Py_IncRef") {
+      Update(symbol_name, &py_inc_ref, fptr);
+    } else if (symbol_name == "Py_DecRef") {
+      Update(symbol_name, &py_dec_ref, fptr);
     } else {
       LOG(FATAL) << "Unknown env API " << symbol_name;
     }
@@ -159,6 +176,18 @@ class EnvCAPIRegistry {
     }
   }
 
+  /*!
+   * \brief Increment a Python object's refcount through the registered Py_IncRef.
+   * \param python_obj Opaque pointer to the Python object (presumably a PyObject*).
+   */
+  void IncRef(void* python_obj) {
+    ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, "
+                       << "but Py_IncRef wasn't registered";
+    (*py_inc_ref)(python_obj);
+  }
+
+  /*!
+   * \brief Decrement a Python object's refcount through the registered Py_DecRef.
+   * \param python_obj Opaque pointer to the Python object (presumably a PyObject*).
+   */
+  void DecRef(void* python_obj) {
+    // Must dispatch to py_dec_ref: routing this through py_inc_ref would make
+    // every release an extra acquire, so wrapped objects would never be freed.
+    ICHECK(py_dec_ref) << "Attempted to call Py_DecRef through EnvCAPIRegistry, "
+                       << "but Py_DecRef wasn't registered";
+    (*py_dec_ref)(python_obj);
+  }
+
  private:
   // update the internal API table
   template <typename FType>
@@ -173,6 +202,35 @@ class EnvCAPIRegistry {
 
 void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); }
 
+WrappedPythonObject::WrappedPythonObject(void* python_obj) : python_obj_(python_obj) {
+  // Acquire a Python reference for as long as this wrapper holds the pointer.
+  if (python_obj_ != nullptr) EnvCAPIRegistry::Global()->IncRef(python_obj_);
+}
+
+WrappedPythonObject::~WrappedPythonObject() {
+  // Release the reference held by this wrapper; nothing to do when it is null.
+  if (python_obj_ != nullptr) EnvCAPIRegistry::Global()->DecRef(python_obj_);
+}
+
+// Moves hand the pointer over without touching the Python refcount.
+WrappedPythonObject::WrappedPythonObject(WrappedPythonObject&& other) : python_obj_(nullptr) {
+  std::swap(python_obj_, other.python_obj_);  // `other` is left holding nullptr
+}
+WrappedPythonObject& WrappedPythonObject::operator=(WrappedPythonObject&& other) {
+  std::swap(python_obj_, other.python_obj_);  // our previous object is released with `other`
+  return *this;
+}
+
+// Copies refer to the same Python object, each holding its own reference.
+WrappedPythonObject::WrappedPythonObject(const WrappedPythonObject& other)
+    : WrappedPythonObject(other.python_obj_) {}
+WrappedPythonObject& WrappedPythonObject::operator=(const WrappedPythonObject& other) {
+  return *this = WrappedPythonObject(other);  // copy, then swap in via move assignment
+}
+WrappedPythonObject& WrappedPythonObject::operator=(std::nullptr_t) {
+  return *this = WrappedPythonObject(nullptr);  // drops the currently held reference
+}
+
 }  // namespace runtime
 }  // namespace tvm
 
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 6e7dec4cb7..e00b9b8d05 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -63,6 +63,18 @@ 
TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMR
   *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); });
 });
 
+// Wraps args[0] in a nullary callback that invokes it and discards any
+// std::exception it throws.
+TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      PackedFunc pf = args[0];
+      auto result = runtime::TypedPackedFunc<void()>([pf]() {
+        try {
+          pf();
+        } catch (const std::exception&) {
+          // Intentionally suppressed, as the registered name says: the error
+          // raised by `pf` must not escape this callback.
+        }
+      });
+      *ret = result;
+    });
+
 TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       std::string msg = args[0];
diff --git a/tests/python/relay/test_pass_instrument.py 
b/tests/python/relay/test_pass_instrument.py
index 83ddc4ef37..455cf20b5d 100644
--- a/tests/python/relay/test_pass_instrument.py
+++ b/tests/python/relay/test_pass_instrument.py
@@ -226,7 +226,7 @@ def test_enter_pass_ctx_exception():
             raise RuntimeError("Just a dummy error")
 
     pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), 
PIBroken("%2"), PI("%3")])
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with pass_ctx:
             pass
         assert "Just a dummy error" in str(cm.execption)
@@ -246,7 +246,7 @@ def test_enter_pass_ctx_exception_global():
             raise RuntimeError("Just a dummy error")
 
     cur_pass_ctx = tvm.transform.PassContext.current()
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         cur_pass_ctx.override_instruments([PIBroken()])
         assert "Just a dummy error" in str(cm.exception)
     assert not cur_pass_ctx.instruments
@@ -273,7 +273,7 @@ def test_exit_pass_ctx_exception():
             raise RuntimeError("Just a dummy error")
 
     pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), 
PIBroken("%2"), PI("%3")])
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with pass_ctx:
             pass
         assert "Just a dummy error" in str(cm.exception)
@@ -293,7 +293,7 @@ def test_exit_pass_ctx_exception_global():
             raise RuntimeError("Just a dummy error")
 
     cur_pass_ctx = tvm.transform.PassContext.current()
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         cur_pass_ctx.override_instruments([PIBroken()])
         cur_pass_ctx.override_instruments([PIBroken()])
         assert "Just a dummy error" in str(cm.exception)
@@ -328,7 +328,7 @@ def test_pass_exception():
         return mod
 
     mod = get_test_model()
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with tvm.transform.PassContext(instruments=[PI()]):
             mod = transform(mod)
         assert "Just a dummy error" in str(cm.exception)
@@ -373,7 +373,7 @@ def test_should_run_exception():
         return mod
 
     mod = get_test_model()
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]):
             mod = transform(mod)
         assert "Just a dummy error" in str(cm.exception)
@@ -418,7 +418,7 @@ def test_run_before_exception():
         return mod
 
     mod = get_test_model()
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]):
             mod = transform(mod)
         assert "Just a dummy error" in str(cm.exception)
@@ -467,7 +467,7 @@ def test_run_after_exception():
     x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"]
     mod = tvm.IRModule.from_expr(tvm.relay.add(x, y))
 
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(RuntimeError) as cm:
         with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]):
             mod = transform(mod)
         assert "Just a dummy error" in str(cm.exception)
diff --git a/tests/python/relay/test_type_infer.py 
b/tests/python/relay/test_type_infer.py
index 7fbb656b36..ec88143db6 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -541,7 +541,7 @@ def test_custom_op_rel_infer_exception():
     t2 = sb.let("t2", relay.add(t1, x))
     sb.ret(t2)
     f = relay.Function([x], sb.get())
-    with pytest.raises(tvm.error.TVMError) as cm:
+    with pytest.raises(AssertionError) as cm:
         fchecked = infer_expr(f)
         assert "type relation arg number mismatch" in str(cm.execption)
 
diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
index 2bfa3070d1..7222c4d649 100644
--- 
a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
+++ 
b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py
@@ -59,7 +59,7 @@ def test_custom_rule():
                 max_trials_global=10,
                 space=space_gen,
             )
-    assert "ValueError: Intended for meta_schedule.cpu.test_apply_custom_rule" 
in str(e_info.value)
+    assert "Intended for meta_schedule.cpu.test_apply_custom_rule" in 
str(e_info.value)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_runtime_error.py 
b/tests/python/unittest/test_runtime_error.py
index 3d7a218099..efb373ac87 100644
--- a/tests/python/unittest/test_runtime_error.py
+++ b/tests/python/unittest/test_runtime_error.py
@@ -15,12 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test runtime error handling"""
+
+import functools
+import platform
+import subprocess
+import traceback
+
+import pytest
+
 import tvm
-from tvm import te
 import tvm.testing
 
 
-def test_op_translation():
+def test_op_translation_to_not_implemented():
     ferror = tvm.testing.test_raise_error_callback("OpNotImplemented: myop")
     try:
         ferror()
@@ -30,6 +37,8 @@ def test_op_translation():
         assert isinstance(e, NotImplementedError)
         assert msg.find("ffi_testing.cc") != -1
 
+
+def test_op_translation_to_internal_error():
     fchk_eq = tvm.testing.test_check_eq_callback("InternalError: myop")
     try:
         fchk_eq(0, 1)
@@ -38,6 +47,8 @@ def test_op_translation():
         msg = str(e)
         assert msg.find("ffi_testing.cc") != -1
 
+
+def test_op_translation_to_value_error():
     try:
         tvm.testing.ErrorTest(0, 1)
         assert False
@@ -47,6 +58,18 @@ def test_op_translation():
 
 
 def test_deep_callback():
+    """Propagate python errors through API calls
+
+    If a Python exception is raised, and that exception is caught in
+    Python, the original exception should be propagated so that the
+    traceback contains all intermediate python frames.
+
+    Expected Python frames in the traceback, outermost first:
+    - test_deep_callback
+    - flevel3
+    - flevel2
+    - error_callback
+
+    """
+
     def error_callback():
         raise ValueError("callback error")
 
@@ -65,14 +88,73 @@ def test_deep_callback():
     try:
         wrap3()
         assert False
-    except ValueError as e:
-        msg = str(e)
-        idx2 = msg.find("in flevel2")
-        idx3 = msg.find("in flevel3")
-        assert idx2 != -1 and idx3 != -1
-        assert idx2 > idx3
+    except ValueError as err:
+        frames = traceback.extract_tb(err.__traceback__)
+
+    local_frames = [frame.name for frame in frames if frame.filename == 
__file__]
+    assert local_frames == ["test_deep_callback", "flevel3", "flevel2", 
"error_callback"]
+
+
[email protected]_cache()
+def _has_debug_symbols():
+    lib = tvm._ffi.base._LIB
+    headers = subprocess.check_output(["objdump", "--section-headers", 
lib._name], encoding="utf-8")
+    return ".debug" in headers
+
+
[email protected](
+    not _has_debug_symbols() or platform.machine != "x86_64",
+    reason="C++ stack frames require debug symbols, only implemented for x86",
+)
+def test_cpp_frames_in_stack_trace_from_python_error():
+    """A python exception crossing C++ boundaries should have C++ stack 
frames"""
+
+    def error_callback():
+        raise ValueError("callback error")
+
+    wrapped = tvm.testing.test_wrap_callback(error_callback)
+
+    try:
+        wrapped()
+        assert False
+    except ValueError as err:
+        frames = traceback.extract_tb(err.__traceback__)
+
+        cpp_frames = [
+            frame
+            for frame in frames
+            if frame.filename.endswith(".cc") or frame.filename.endswith(".c")
+        ]
+        assert len(cpp_frames) >= 1, (
+            f"Traceback through files '{[frame.filename for frame in frames]}'"
+            f" expected to contain C/C++ frames, "
+            f" but instead caught exception {err}"
+        )
+
+
[email protected](
+    not _has_debug_symbols() or platform.machine != "x86_64",
+    reason="C++ stack frames require debug symbols, only implemented for x86",
+)
+def test_stack_trace_from_cpp_error():
+    """A python exception originating in C++ should have C++ stack frames"""
+    try:
+        tvm.testing.ErrorTest(0, 1)
+        assert False
+    except ValueError as err:
+        frames = traceback.extract_tb(err.__traceback__)
+
+        cpp_frames = [
+            frame
+            for frame in frames
+            if frame.filename.endswith(".cc") or frame.filename.endswith(".c")
+        ]
+        assert len(cpp_frames) >= 1, (
+            f"Traceback through files '{[frame.filename for frame in frames]}'"
+            f" expected to contain C/C++ frames, "
+            f" but instead caught exception {err}"
+        )
 
 
 if __name__ == "__main__":
-    test_op_translation()
-    test_deep_callback()
+    tvm.testing.main()


Reply via email to