This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e1bd421  [FIX] Fix the error propagation in the case of tensor arguments (#409)
e1bd421 is described below

commit e1bd42189949b360753f50784ceb5e64cb08254f
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jan 13 12:37:38 2026 -0500

    [FIX] Fix the error propagation in the case of tensor arguments (#409)
    
    This PR fixes error propagation in the case of tensor arguments. The bug
    was previously hidden and was only revealed after a fix landed in 0.1.8,
    so it does not impact previous versions. A regression test is added to
    cover this case.
---
 include/tvm/ffi/c_api.h                        |  2 +-
 python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 10 ++++++++--
 src/ffi/extra/env_context.cc                   |  3 ++-
 tests/python/test_tensor.py                    | 16 +++++++++++++++-
 4 files changed, 26 insertions(+), 5 deletions(-)
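
For readers skimming the patch: the behavior being fixed is that an exception raised inside a Python callback converted via tvm_ffi.convert must propagate back to the caller even when the callback is invoked with a tensor argument, since the tensor path temporarily installs a DLPack managed tensor allocator. A minimal sketch of that expectation, modeled on the regression test added below (the helper name and tensor size here are illustrative, not part of the patch):

    import torch
    import tvm_ffi

    def failing_callback(x):
        # raising while a tensor argument (and hence a temporarily swapped
        # DLPack allocator) is active exercises the fixed error path
        raise ValueError("error XYZ")

    f = tvm_ffi.convert(failing_callback)
    try:
        f(torch.arange(4))
    except ValueError as err:
        # with this fix the original ValueError reaches the caller instead of
        # being lost while the previous allocator is restored
        print("propagated:", err)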

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index a5a581a..4da0d4d 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -62,7 +62,7 @@
 /*! \brief TVM FFI minor version. */
 #define TVM_FFI_VERSION_MINOR 1
 /*! \brief TVM FFI patch version. */
-#define TVM_FFI_VERSION_PATCH 8
+#define TVM_FFI_VERSION_PATCH 9
 // NOLINTEND(modernize-macro-to-enum)
 
 #ifdef __cplusplus
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 93d6540..666bb6e 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -321,10 +321,16 @@ class TVMFFIPyCallManager {
           return -1;
         }
       }
+
       if (ctx.dlpack_c_exchange_api != nullptr &&
           prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
-        c_api_ret_code[0] =
-            TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr);
+        // note: we cannot set the error value to c_api_ret_code[0] here because doing so
+        // would overwrite the error value from the function call
+        if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr) != 0) {
+          PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack managed tensor allocator");
+          return -1;
+        }
+        // propagate the error from the function call after the allocator has been restored
         if (c_api_ret_code[0] != 0) return 0;
       }
       if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) {
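
The hunk above changes the restore step so that a failure to restore the previous allocator is reported as its own RuntimeError instead of being written into c_api_ret_code[0], where it would clobber the return code of the user's function; the function's error is only consulted after the allocator has been put back. A rough Python-level analogue of that cleanup-then-propagate ordering (illustrative only, not code from the patch):

    def call_with_temporary_allocator(func, install, restore):
        """Run func() with a temporarily installed allocator.

        install() and restore() stand in for the C-level
        TVMFFIEnvSetDLPackManagedTensorAllocator bookkeeping.
        """
        install()
        try:
            return func()
        finally:
            # restore() runs whether or not func() raised; if restore() itself
            # fails, its exception surfaces, mirroring the new RuntimeError path,
            # while a successful restore lets the original error propagate.
            restore()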
diff --git a/src/ffi/extra/env_context.cc b/src/ffi/extra/env_context.cc
index 9b2fb25..95045d4 100644
--- a/src/ffi/extra/env_context.cc
+++ b/src/ffi/extra/env_context.cc
@@ -66,7 +66,8 @@ class EnvContext {
                                        int write_to_global_context,
                                        DLPackManagedTensorAllocator* opt_out_original_allocator) {
     if (opt_out_original_allocator != nullptr) {
-      *opt_out_original_allocator = GetDLPackManagedTensorAllocator();
+      // only return the cached local allocator and ignore the global allocator
+      *opt_out_original_allocator = dlpack_allocator_;
     }
     if (write_to_global_context != 0) {
       GlobalTensorAllocator() = allocator;
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index d091d85..0c45654 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from types import ModuleType
-from typing import Any, NamedTuple
+from typing import Any, NamedTuple, NoReturn
 
 import numpy.typing as npt
 import pytest
@@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None:
     np.testing.assert_equal(y.numpy(), x.numpy())
 
 
+@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled")
+def test_tensor_auto_dlpack_with_error() -> None:
+    assert torch is not None
+    x = torch.arange(128)
+
+    def raise_torch_error(x: Any) -> NoReturn:
+        raise ValueError("error XYZ")
+
+    f = tvm_ffi.convert(raise_torch_error)
+    with pytest.raises(ValueError):
+        # pass in a torch argument to trigger the error in the set allocator path
+        f(x)
+
+
 def test_tensor_class_override() -> None:
     class MyTensor(tvm_ffi.Tensor):
         pass
