This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e1bd421 [FIX] Fix the error propagation in the case of tensor
arguments (#409)
e1bd421 is described below
commit e1bd42189949b360753f50784ceb5e64cb08254f
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jan 13 12:37:38 2026 -0500
[FIX] Fix the error propagation in the case of tensor arguments (#409)
This PR fixes error propagation in the case of tensor arguments. The bug
was previously hidden and was only exposed by a fix that landed in 0.1.8,
so it does not affect earlier versions. A regression test is added to
cover this case.
---
include/tvm/ffi/c_api.h | 2 +-
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 10 ++++++++--
src/ffi/extra/env_context.cc | 3 ++-
tests/python/test_tensor.py | 16 +++++++++++++++-
4 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index a5a581a..4da0d4d 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -62,7 +62,7 @@
/*! \brief TVM FFI minor version. */
#define TVM_FFI_VERSION_MINOR 1
/*! \brief TVM FFI patch version. */
-#define TVM_FFI_VERSION_PATCH 9
+#define TVM_FFI_VERSION_PATCH 8
// NOLINTEND(modernize-macro-to-enum)
#ifdef __cplusplus
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 93d6540..666bb6e 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -321,10 +321,16 @@ class TVMFFIPyCallManager {
return -1;
}
}
+
if (ctx.dlpack_c_exchange_api != nullptr &&
prev_tensor_allocator !=
ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
- c_api_ret_code[0] =
- TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0,
nullptr);
+ // note: we cannot store this call's return code in c_api_ret_code[0]
+ // because that would overwrite the error code from the function call
+ if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0,
nullptr) != 0) {
+ PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack
managed tensor allocator");
+ return -1;
+ }
+ // allocator restored; now propagate any error from the function call
if (c_api_ret_code[0] != 0) return 0;
}
if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api
!= nullptr) {
diff --git a/src/ffi/extra/env_context.cc b/src/ffi/extra/env_context.cc
index 9b2fb25..95045d4 100644
--- a/src/ffi/extra/env_context.cc
+++ b/src/ffi/extra/env_context.cc
@@ -66,7 +66,8 @@ class EnvContext {
int write_to_global_context,
DLPackManagedTensorAllocator*
opt_out_original_allocator) {
if (opt_out_original_allocator != nullptr) {
- *opt_out_original_allocator = GetDLPackManagedTensorAllocator();
+ // only returns the cached local allocator and ignores the global allocator
+ *opt_out_original_allocator = dlpack_allocator_;
}
if (write_to_global_context != 0) {
GlobalTensorAllocator() = allocator;
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index d091d85..0c45654 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -18,7 +18,7 @@
from __future__ import annotations
from types import ModuleType
-from typing import Any, NamedTuple
+from typing import Any, NamedTuple, NoReturn
import numpy.typing as npt
import pytest
@@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None:
np.testing.assert_equal(y.numpy(), x.numpy())
+@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled")
+def test_tensor_auto_dlpack_with_error() -> None:
+ assert torch is not None
+ x = torch.arange(128)
+
+ def raise_torch_error(x: Any) -> NoReturn:
+ raise ValueError("error XYZ")
+
+ f = tvm_ffi.convert(raise_torch_error)
+ with pytest.raises(ValueError):
+ # pass in a torch argument to trigger the error in the set-allocator path
+ f(x)
+
+
def test_tensor_class_override() -> None:
class MyTensor(tvm_ffi.Tensor):
pass