This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new f679fe5 [ABI] Refactor the TVMFFIEnvTensorAllocator to align with
DLPack (#131)
f679fe5 is described below
commit f679fe54cf2e78ac175d4644a19b8713a9576c62
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Oct 15 13:46:26 2025 -0400
[ABI] Refactor the TVMFFIEnvTensorAllocator to align with DLPack (#131)
This PR refactors the TVMFFIEnvTensorAllocator to align with the
finalized naming of the DLPack standard.
- TVMFFIEnvSetTensorAllocator ->
TVMFFIEnvSetDLPackManagedTensorAllocator
- TVMFFIEnvGetTensorAllocator ->
TVMFFIEnvGetDLPackManagedTensorAllocator
We also introduced TVMFFIEnvTensorAlloc to directly allocate a
ffi::Tensor from the given prototype which can be used by DSL compilers.
Note that the new TVMFFIEnvTensorAlloc makes metadata allocation in
libtvm_ffi so it won't suffer from the module unloading order problem.
We removed Tensor::FromDLPackAlloc in favor of Tensor::FromEnvAlloc that
makes use of TVMFFIEnvTensorAlloc.
Hopefully this step allows us to stabilize the tensor alloc API to align
with DLPack before we freeze.
---
include/tvm/ffi/container/tensor.h | 61 +++++-------------
include/tvm/ffi/extra/c_env_api.h | 24 +++++--
pyproject.toml | 9 ++-
python/tvm_ffi/__init__.py | 2 +-
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 5 +-
rust/tvm-ffi-sys/src/c_env_api.rs | 4 +-
src/ffi/extra/env_context.cc | 43 +++++++++++--
tests/cpp/extra/test_c_env_api.cc | 86 ++++++++++++++++++++++++++
tests/cpp/test_tensor.cc | 26 +++-----
tests/python/test_load_inline.py | 8 +--
10 files changed, 188 insertions(+), 80 deletions(-)
diff --git a/include/tvm/ffi/container/tensor.h
b/include/tvm/ffi/container/tensor.h
index 00cc402..3164267 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -359,66 +359,39 @@ class Tensor : public ObjectRef {
std::forward<ExtraArgs>(extra_args)...));
}
/*!
- * \brief Create a Tensor from a DLPackManagedTensorAllocator
+ * \brief Create a Tensor from the TVMFFIEnvTensorAlloc API
*
- * This function can be used together with TVMFFIEnvSetTensorAllocator
- * in the extra/c_env_api.h to create Tensor from the thread-local
- * environment allocator.
+ * This function can be used together with
TVMFFIEnvSetDLPackManagedTensorAllocator
+ * in the extra/c_env_api.h to create a Tensor from the thread-local
environment allocator.
+ * We explicitly pass TVMFFIEnvTensorAlloc to maintain explicit dependency
on extra/c_env_api.h
*
* \code
*
- * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc(
- * TVMFFIEnvGetTensorAllocator(), shape, dtype, device
+ * ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc(
+ * TVMFFIEnvTensorAlloc, shape, dtype, device
* );
+ *
* \endcode
*
- * \param allocator The DLPack allocator.
+ * \param env_alloc TVMFFIEnvTensorAlloc function pointer.
* \param shape The shape of the Tensor.
* \param dtype The data type of the Tensor.
* \param device The device of the Tensor.
* \return The created Tensor.
+ *
+ * \sa TVMFFIEnvTensorAlloc
*/
- static Tensor FromDLPackAlloc(DLPackManagedTensorAllocator allocator,
ffi::ShapeView shape,
- DLDataType dtype, DLDevice device) {
- if (allocator == nullptr) {
- TVM_FFI_THROW(RuntimeError)
- << "FromDLPackAlloc: allocator is nullptr, "
- << "likely because TVMFFIEnvSetTensorAllocator has not been called.";
- }
- DLTensor prototype;
+ static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor*, TVMFFIObjectHandle*),
ffi::ShapeView shape,
+ DLDataType dtype, DLDevice device) {
+ TVMFFIObjectHandle out;
+ DLTensor prototype{};
prototype.device = device;
prototype.dtype = dtype;
prototype.shape = const_cast<int64_t*>(shape.data());
prototype.ndim = static_cast<int>(shape.size());
- prototype.strides = nullptr;
- prototype.byte_offset = 0;
- prototype.data = nullptr;
- DLManagedTensorVersioned* tensor = nullptr;
- // error context to be used to propagate error
- struct ErrorContext {
- std::string kind;
- std::string message;
- static void SetError(void* error_ctx, const char* kind, const char*
message) {
- ErrorContext* error_context = static_cast<ErrorContext*>(error_ctx);
- error_context->kind = kind;
- error_context->message = message;
- }
- };
- ErrorContext error_context;
- int ret = (*allocator)(&prototype, &tensor, &error_context,
ErrorContext::SetError);
- if (ret != 0) {
- throw ffi::Error(error_context.kind, error_context.message,
- TVMFFIBacktrace(__FILE__, __LINE__, __func__, 0));
- }
- if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
- return
Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
- tensor, /*extra_strides_at_tail=*/false));
- } else {
- return Tensor(
-
make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
- int64_t>(tensor->dl_tensor.ndim, tensor,
- /*extra_strides_at_tail=*/true));
- }
+ TVM_FFI_CHECK_SAFE_CALL(env_alloc(&prototype, &out));
+ return Tensor(
+
details::ObjectUnsafe::ObjectPtrFromOwned<TensorObj>(static_cast<TVMFFIObject*>(out)));
}
/*!
* \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API.
diff --git a/include/tvm/ffi/extra/c_env_api.h
b/include/tvm/ffi/extra/c_env_api.h
index 8276825..6192db5 100644
--- a/include/tvm/ffi/extra/c_env_api.h
+++ b/include/tvm/ffi/extra/c_env_api.h
@@ -63,26 +63,40 @@ TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type,
int32_t device_id,
TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t
device_id);
/*!
- * \brief FFI function to set the current DLPack allocator in
thread-local(TLS) context
+ * \brief Set the current DLPackManagedTensorAllocator in thread-local(TLS)
context
*
* \param allocator The allocator to set.
* \param write_to_global_context Whether to also set the allocator to the
global context.
* \param opt_out_original_allocator Output original TLS allocator if the
address is not nullptr.
* \return 0 when success, nonzero when failure happens
*/
-TVM_FFI_DLL int TVMFFIEnvSetTensorAllocator(
+TVM_FFI_DLL int TVMFFIEnvSetDLPackManagedTensorAllocator(
DLPackManagedTensorAllocator allocator, int write_to_global_context,
DLPackManagedTensorAllocator* opt_out_original_allocator);
/*!
- * \brief FFI function get the current DLPack allocator stored in context.
+ * \brief FFI function to get the current DLPackManagedTensorAllocator stored in
context.
*
* This function first queries the global context, and if not found,
* queries the thread-local context.
*
- * \return The current DLPack allocator.
+ * \return The currently set DLPackManagedTensorAllocator
*/
-TVM_FFI_DLL DLPackManagedTensorAllocator TVMFFIEnvGetTensorAllocator();
+TVM_FFI_DLL DLPackManagedTensorAllocator
TVMFFIEnvGetDLPackManagedTensorAllocator();
+
+/*!
+ * \brief Allocate a tensor from the allocator set in thread-local(TLS)
context.
+ *
+ * This function redirects to one of the environment allocators. As of now, we only
+ * support the DLPackManagedTensorAllocator set in thread-local(TLS) context.
+ *
+ * \param prototype The prototype DLTensor, only the dtype, ndim, shape,
+ * and device fields are used, other fields are ignored.
+ * \param out The output tensor in kTVMFFITensor type.
+ * \return 0 when success, nonzero when failure happens
+ * \sa TVMFFIEnvSetDLPackManagedTensorAllocator
+ */
+TVM_FFI_DLL int TVMFFIEnvTensorAlloc(DLTensor* prototype, TVMFFIObjectHandle*
out);
/*!
* \brief Check if there are any signals raised in the surrounding env.
diff --git a/pyproject.toml b/pyproject.toml
index 5363c87..732bc2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b18"
+version = "0.1.0b19"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
@@ -43,7 +43,12 @@ GitHub = "https://github.com/apache/tvm-ffi"
torch = ["torch", "setuptools", "ninja"]
cpp = ["ninja"]
# note pytorch does not yet ship with 3.14t
-test = ["pytest", "numpy", "ninja", "torch; python_version < '3.14'"]
+test = [
+ "pytest",
+ "numpy",
+ "ninja",
+ "torch; python_version < '3.14' and sys_platform != 'win32'",
+]
[dependency-groups]
docs = [
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index f6d9045..e2bf619 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
"""TVM FFI Python package."""
# version
-__version__ = "0.1.0b18"
+__version__ = "0.1.0b19"
# order matters here so we need to skip isort here
# isort: skip_file
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index bfd5d60..3ab8665 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -271,7 +271,7 @@ class TVMFFIPyCallManager {
}
if (ctx.c_dlpack_exchange_api != nullptr &&
ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
- c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(
+ c_api_ret_code[0] = TVMFFIEnvSetDLPackManagedTensorAllocator(
ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0,
&prev_tensor_allocator);
if (c_api_ret_code[0] != 0) return 0;
}
@@ -295,7 +295,8 @@ class TVMFFIPyCallManager {
}
if (ctx.c_dlpack_exchange_api != nullptr &&
prev_tensor_allocator !=
ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
- c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator,
0, nullptr);
+ c_api_ret_code[0] =
+ TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0,
nullptr);
if (c_api_ret_code[0] != 0) return 0;
}
if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api
!= nullptr) {
diff --git a/rust/tvm-ffi-sys/src/c_env_api.rs
b/rust/tvm-ffi-sys/src/c_env_api.rs
index e0b9306..94feec1 100644
--- a/rust/tvm-ffi-sys/src/c_env_api.rs
+++ b/rust/tvm-ffi-sys/src/c_env_api.rs
@@ -53,13 +53,13 @@ unsafe extern "C" {
pub fn TVMFFIEnvGetStream(device_type: i32, device_id: i32) ->
TVMFFIStreamHandle;
- pub fn TVMFFIEnvSetTensorAllocator(
+ pub fn TVMFFIEnvSetDLPackManagedTensorAllocator(
allocator: DLPackManagedTensorAllocator,
write_to_global_context: i32,
opt_out_original_allocator: *mut DLPackManagedTensorAllocator,
) -> i32;
- pub fn TVMFFIEnvGetTensorAllocator() -> DLPackManagedTensorAllocator;
+ pub fn TVMFFIEnvGetDLPackManagedTensorAllocator() ->
DLPackManagedTensorAllocator;
pub fn TVMFFIEnvCheckSignals() -> i32;
diff --git a/src/ffi/extra/env_context.cc b/src/ffi/extra/env_context.cc
index cb68b53..e8f5a63 100644
--- a/src/ffi/extra/env_context.cc
+++ b/src/ffi/extra/env_context.cc
@@ -22,6 +22,7 @@
* \brief A minimalistic env context based on ffi values.
*/
+#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
@@ -106,16 +107,50 @@ TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t
device_type, int32_t device_id) {
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetStream);
}
-int TVMFFIEnvSetTensorAllocator(DLPackManagedTensorAllocator allocator, int
write_to_global_context,
- DLPackManagedTensorAllocator*
opt_out_original_allocator) {
+int TVMFFIEnvSetDLPackManagedTensorAllocator(
+ DLPackManagedTensorAllocator allocator, int write_to_global_context,
+ DLPackManagedTensorAllocator* opt_out_original_allocator) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::EnvContext::ThreadLocal()->SetDLPackManagedTensorAllocator(
allocator, write_to_global_context, opt_out_original_allocator);
TVM_FFI_SAFE_CALL_END();
}
-DLPackManagedTensorAllocator TVMFFIEnvGetTensorAllocator() {
+DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator() {
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
return
tvm::ffi::EnvContext::ThreadLocal()->GetDLPackManagedTensorAllocator();
- TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetTensorAllocator);
+ TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetDLPackManagedTensorAllocator);
+}
+
+void TVMFFIEnvTensorAllocSetError(void* error_ctx, const char* kind, const
char* message) {
+ TVMFFIErrorSetRaisedFromCStr(kind, message);
+}
+
+int TVMFFIEnvTensorAlloc(DLTensor* prototype, TVMFFIObjectHandle* out) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ DLPackManagedTensorAllocator dlpack_alloc =
+ tvm::ffi::EnvContext::ThreadLocal()->GetDLPackManagedTensorAllocator();
+ if (dlpack_alloc == nullptr) {
+ TVMFFIErrorSetRaisedFromCStr(
+ "RuntimeError",
+ "TVMFFIEnvTensorAlloc: allocator is nullptr, "
+ "likely because TVMFFIEnvSetDLPackManagedTensorAllocator has not been
called.");
+ return -1;
+ }
+ DLManagedTensorVersioned* dlpack_tensor = nullptr;
+ int ret = (*dlpack_alloc)(prototype, &dlpack_tensor, nullptr,
TVMFFIEnvTensorAllocSetError);
+ TVM_FFI_ICHECK(dlpack_tensor != nullptr);
+ if (ret != 0) return ret;
+ if (dlpack_tensor->dl_tensor.strides != nullptr ||
dlpack_tensor->dl_tensor.ndim == 0) {
+ *out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(
+
tvm::ffi::make_object<tvm::ffi::details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
+ dlpack_tensor, /*extra_strides_at_tail=*/false));
+ } else {
+ *out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(
+ tvm::ffi::make_inplace_array_object<
+ tvm::ffi::details::TensorObjFromDLPack<DLManagedTensorVersioned>,
int64_t>(
+ dlpack_tensor->dl_tensor.ndim, dlpack_tensor,
+ /*extra_strides_at_tail=*/true));
+ }
+ TVM_FFI_SAFE_CALL_END();
}
diff --git a/tests/cpp/extra/test_c_env_api.cc
b/tests/cpp/extra/test_c_env_api.cc
new file mode 100644
index 0000000..ff8a845
--- /dev/null
+++ b/tests/cpp/extra/test_c_env_api.cc
@@ -0,0 +1,86 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/extra/c_env_api.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+struct CPUNDAlloc {
+ void AllocData(DLTensor* tensor) { tensor->data =
malloc(GetDataSize(*tensor)); }
+ void FreeData(DLTensor* tensor) { free(tensor->data); }
+};
+
+inline Tensor Empty(const Shape& shape, DLDataType dtype, DLDevice device) {
+ return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
+}
+
+int TestDLPackManagedTensorAllocator(DLTensor* prototype,
DLManagedTensorVersioned** out,
+ void* error_ctx,
+ void (*SetError)(void* error_ctx, const
char* kind,
+ const char* message)) {
+ Shape shape(prototype->shape, prototype->shape + prototype->ndim);
+ Tensor nd = Empty(shape, prototype->dtype, prototype->device);
+ *out = nd.ToDLPackVersioned();
+ return 0;
+}
+
+int TestDLPackManagedTensorAllocatorError(DLTensor* prototype,
DLManagedTensorVersioned** out,
+ void* error_ctx,
+ void (*SetError)(void* error_ctx,
const char* kind,
+ const char*
message)) {
+ SetError(error_ctx, "RuntimeError", "TestDLPackManagedTensorAllocatorError");
+ return -1;
+}
+
+TEST(CEnvAPI, TVMFFIEnvTensorAlloc) {
+ auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
+ TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator,
0, nullptr);
+ Tensor tensor = Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, {1, 2, 3},
+ DLDataType({kDLFloat, 32, 1}),
DLDevice({kDLCPU, 0}));
+ EXPECT_EQ(tensor.use_count(), 1);
+ EXPECT_EQ(tensor.shape().size(), 3);
+ EXPECT_EQ(tensor.size(0), 1);
+ EXPECT_EQ(tensor.size(1), 2);
+ EXPECT_EQ(tensor.size(2), 3);
+ EXPECT_EQ(tensor.dtype().code, kDLFloat);
+ EXPECT_EQ(tensor.dtype().bits, 32);
+ EXPECT_EQ(tensor.dtype().lanes, 1);
+ EXPECT_EQ(tensor.device().device_type, kDLCPU);
+ EXPECT_EQ(tensor.device().device_id, 0);
+ EXPECT_NE(tensor.data_ptr(), nullptr);
+ TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
+}
+
+TEST(CEnvAPI, TVMFFIEnvTensorAllocError) {
+ auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
+
TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocatorError,
0, nullptr);
+ EXPECT_THROW(
+ {
+ Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, {1, 2, 3},
DLDataType({kDLFloat, 32, 1}),
+ DLDevice({kDLCPU, 0}));
+ },
+ tvm::ffi::Error);
+ TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
+}
+
+} // namespace
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index 60d9a9a..41f7041 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -32,21 +32,15 @@ inline Tensor Empty(const Shape& shape, DLDataType dtype,
DLDevice device) {
return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
}
-int TestDLPackManagedTensorAllocator(DLTensor* prototype,
DLManagedTensorVersioned** out,
- void* error_ctx,
- void (*SetError)(void* error_ctx, const
char* kind,
- const char* message)) {
+int TestEnvTensorAllocator(DLTensor* prototype, TVMFFIObjectHandle* out) {
Shape shape(prototype->shape, prototype->shape + prototype->ndim);
Tensor nd = Empty(shape, prototype->dtype, prototype->device);
- *out = nd.ToDLPackVersioned();
+ *out =
tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd));
return 0;
}
-int TestDLPackManagedTensorAllocatorError(DLTensor* prototype,
DLManagedTensorVersioned** out,
- void* error_ctx,
- void (*SetError)(void* error_ctx,
const char* kind,
- const char*
message)) {
- SetError(error_ctx, "RuntimeError", "TestDLPackManagedTensorAllocatorError");
+int TestEnvTensorAllocatorError(DLTensor* prototype, TVMFFIObjectHandle* out) {
+ TVMFFIErrorSetRaisedFromCStr("RuntimeError", "TestEnvTensorAllocatorError");
return -1;
}
@@ -137,10 +131,10 @@ TEST(Tensor, DLPackVersioned) {
EXPECT_EQ(tensor.use_count(), 1);
}
-TEST(Tensor, DLPackAlloc) {
+TEST(Tensor, EnvAlloc) {
// Test successful allocation
- Tensor tensor = Tensor::FromDLPackAlloc(TestDLPackManagedTensorAllocator,
{1, 2, 3},
- DLDataType({kDLFloat, 32, 1}),
DLDevice({kDLCPU, 0}));
+ Tensor tensor = Tensor::FromEnvAlloc(TestEnvTensorAllocator, {1, 2, 3},
+ DLDataType({kDLFloat, 32, 1}),
DLDevice({kDLCPU, 0}));
EXPECT_EQ(tensor.use_count(), 1);
EXPECT_EQ(tensor.shape().size(), 3);
EXPECT_EQ(tensor.size(0), 1);
@@ -154,12 +148,12 @@ TEST(Tensor, DLPackAlloc) {
EXPECT_NE(tensor.data_ptr(), nullptr);
}
-TEST(Tensor, DLPackAllocError) {
+TEST(Tensor, EnvAllocError) {
// Test error handling in DLPackAlloc
EXPECT_THROW(
{
- Tensor::FromDLPackAlloc(TestDLPackManagedTensorAllocatorError, {1, 2,
3},
- DLDataType({kDLFloat, 32, 1}),
DLDevice({kDLCPU, 0}));
+ Tensor::FromEnvAlloc(TestEnvTensorAllocatorError, {1, 2, 3},
DLDataType({kDLFloat, 32, 1}),
+ DLDevice({kDLCPU, 0}));
},
tvm::ffi::Error);
}
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 5b3e18d..f39c485 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -232,8 +232,8 @@ def test_load_inline_with_env_tensor_allocator() -> None:
TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float
tensor";
// allocate a new tensor with the env tensor allocator
// it will be redirected to torch.empty when calling the function
- ffi::Tensor y = ffi::Tensor::FromDLPackAlloc(
- TVMFFIEnvGetTensorAllocator(), ffi::Shape({x.size(0)}),
f32_dtype, x.device());
+ ffi::Tensor y = ffi::Tensor::FromEnvAlloc(
+ TVMFFIEnvTensorAlloc, ffi::Shape({x.size(0)}), f32_dtype,
x.device());
int64_t n = x.size(0);
for (int i = 0; i < n; ++i) {
static_cast<float*>(y.data_ptr())[i] =
static_cast<float*>(x.data_ptr())[i] + 1;
@@ -343,8 +343,8 @@ def test_cuda_memory_alloc_noleak() -> None:
namespace ffi = tvm::ffi;
ffi::Tensor return_tensor(tvm::ffi::TensorView x) {
- ffi::Tensor y = ffi::Tensor::FromDLPackAlloc(
- TVMFFIEnvGetTensorAllocator(), x.shape(), x.dtype(),
x.device());
+ ffi::Tensor y = ffi::Tensor::FromEnvAlloc(
+ TVMFFIEnvTensorAlloc, x.shape(), x.dtype(), x.device());
return y;
}
""",