This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 1ec6236 [FEAT] Introduce TensorView for non-owning view (#81)
1ec6236 is described below
commit 1ec623678adea0ddba482d8d56d4ab2be440e694
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Oct 1 16:03:57 2025 -0400
[FEAT] Introduce TensorView for non-owning view (#81)
When exposing an FFI function, sometimes it is not always possible to
pass in an owning reference of a Tensor, in such case, it is better to
interface with the Tensor using non-owning view.
This PR adds a ffi::TensorView class to represent in-memory non-owning
DLTensor view. It can be converted from DLTensor* in the ffi argument
but cannot be promoted to an owning Tensor.
We recommend kernel ops to make use of TensorView when possible so the
exposed ops can support a broad range of inputs, including non-owning ones.
---
docs/get_started/quick_start.md | 6 +-
docs/guides/packaging.md | 2 +-
docs/guides/python_guide.md | 2 +-
examples/inline_module/main.py | 6 +-
examples/packaging/src/extension.cc | 2 +-
examples/quick_start/run_example.py | 4 +-
examples/quick_start/src/add_one_cpu.cc | 2 +-
examples/quick_start/src/add_one_cuda.cu | 2 +-
include/tvm/ffi/container/tensor.h | 159 +++++++++++++++++++++++++++
include/tvm/ffi/object.h | 2 +
pyproject.toml | 2 +-
python/tvm_ffi/__init__.py | 2 +-
python/tvm_ffi/cpp/load_inline.py | 4 +-
rust/tvm-ffi/scripts/generate_example_lib.py | 2 +-
tests/cpp/test_tensor.cc | 24 ++++
tests/python/test_build_inline.py | 2 +-
tests/python/test_load_inline.py | 18 +--
17 files changed, 213 insertions(+), 28 deletions(-)
diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index 9a4ae3d..2ad8a41 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -103,7 +103,7 @@ namespace tvm_ffi_example {
namespace ffi = tvm::ffi;
-void AddOne(ffi::Tensor x, ffi::Tensor y) {
+void AddOne(ffi::TensorView x, ffi::TensorView y) {
// Validate inputs
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -131,7 +131,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu,
tvm_ffi_example::AddOne);
### CUDA Implementation
```cpp
-void AddOneCUDA(ffi::Tensor x, ffi::Tensor y) {
+void AddOneCUDA(ffi::TensorView x, ffi::TensorView y) {
// Validation (same as CPU version)
// ...
@@ -216,7 +216,7 @@ shows how to run the example exported function in C++.
namespace ffi = tvm::ffi;
-void CallAddOne(ffi::Tensor x, ffi::Tensor y) {
+void CallAddOne(ffi::TensorView x, ffi::TensorView y) {
ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
add_one_cpu(x, y);
diff --git a/docs/guides/packaging.md b/docs/guides/packaging.md
index a50f413..2a82e0a 100644
--- a/docs/guides/packaging.md
+++ b/docs/guides/packaging.md
@@ -149,7 +149,7 @@ which can later be accessed via `tvm_ffi.load_module`.
Here's a basic example of the function implementation:
```c++
-void AddOne(ffi::Tensor x, ffi::Tensor y) {
+void AddOne(ffi::TensorView x, ffi::TensorView y) {
// ... implementation omitted ...
}
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index de2adfa..7b977e7 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -151,7 +151,7 @@ import tvm_ffi.cpp
# define the cpp source code
cpp_source = '''
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py
index 7231132..d0ba6fe 100644
--- a/examples/inline_module/main.py
+++ b/examples/inline_module/main.py
@@ -26,7 +26,7 @@ def main() -> None:
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -39,7 +39,7 @@ def main() -> None:
}
}
- void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y);
+ void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
""",
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
@@ -49,7 +49,7 @@ def main() -> None:
}
}
- void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/examples/packaging/src/extension.cc
b/examples/packaging/src/extension.cc
index 6a7324f..c99bd07 100644
--- a/examples/packaging/src/extension.cc
+++ b/examples/packaging/src/extension.cc
@@ -44,7 +44,7 @@ namespace ffi = tvm::ffi;
*/
void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; }
-void AddOne(ffi::Tensor x, ffi::Tensor y) {
+void AddOne(ffi::TensorView x, ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/examples/quick_start/run_example.py
b/examples/quick_start/run_example.py
index 2e2b7f3..65c188d 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -28,7 +28,7 @@ def run_add_one_cpu() -> None:
x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
y = numpy.empty_like(x)
# tvm-ffi automatically handles DLPack compatible tensors
- # torch tensors can be viewed as ffi::Tensor or DLTensor*
+ # torch tensors can be viewed as ffi::TensorView
# in the background
mod.add_one_cpu(x, y)
print("numpy.result after add_one(x, y)")
@@ -37,7 +37,7 @@ def run_add_one_cpu() -> None:
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
# tvm-ffi automatically handles DLPack compatible tensors
- # torch tensors can be viewed as ffi::Tensor or DLTensor*
+ # torch tensors can be viewed as ffi::TensorView
# in the background
mod.add_one_cpu(x, y)
print("torch.result after add_one(x, y)")
diff --git a/examples/quick_start/src/add_one_cpu.cc
b/examples/quick_start/src/add_one_cpu.cc
index 76b9b37..886af13 100644
--- a/examples/quick_start/src/add_one_cpu.cc
+++ b/examples/quick_start/src/add_one_cpu.cc
@@ -23,7 +23,7 @@
namespace tvm_ffi_example {
-void AddOne(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/examples/quick_start/src/add_one_cuda.cu
b/examples/quick_start/src/add_one_cuda.cu
index 52f1e74..b15f807 100644
--- a/examples/quick_start/src/add_one_cuda.cu
+++ b/examples/quick_start/src/add_one_cuda.cu
@@ -31,7 +31,7 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
}
}
-void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+void AddOneCUDA(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/include/tvm/ffi/container/tensor.h
b/include/tvm/ffi/container/tensor.h
index 675de51..5197d4e 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -458,6 +458,165 @@ class Tensor : public ObjectRef {
TensorObj* get_mutable() const { return const_cast<TensorObj*>(get()); }
};
+/*!
+ * \brief A non-owning view of a Tensor.
+ *
+ * This class stores a light-weight non-owning view of a Tensor.
+ * This is useful for accessing tensor data without retaining a strong
reference to the Tensor.
+ * Since the caller may not always be able to pass in a strongly referenced
tensor.
+ *
+ * It is the user's responsibility to ensure
+ * that the underlying tensor data outlives the `TensorView`.
+ * This responsibility extends to all data pointed to by the underlying
DLTensor.
+ * This includes not only the tensor elements in data but also the memory for
shape and strides.
+ *
+ * When exposing a function that only expects a TensorView, we
recommend using
+ * ffi::TensorView as the argument type instead of ffi::Tensor.
+ */
+class TensorView {
+ public:
+ /*!
+ * \brief Create a TensorView from a Tensor.
+ * \param tensor The input Tensor.
+ */
+ TensorView(const Tensor& tensor) { // NOLINT(*)
+ TVM_FFI_ICHECK(tensor.defined());
+ tensor_ = *tensor.operator->();
+ } // NOLINT(*)
+ /*!
+ * \brief Create a TensorView from a DLTensor.
+ * \param tensor The input DLTensor.
+ */
+ TensorView(const DLTensor* tensor) { // NOLINT(*)
+ TVM_FFI_ICHECK(tensor != nullptr);
+ tensor_ = *tensor;
+ }
+ /*!
+ * \brief Copy constructor.
+ * \param tensor The input TensorView.
+ */
+ TensorView(const TensorView& tensor) = default;
+ /*!
+ * \brief Move constructor.
+ * \param tensor The input TensorView.
+ */
+ TensorView(TensorView&& tensor) = default;
+ /*!
+ * \brief Copy assignment operator.
+ * \param tensor The input TensorView.
+ * \return The created TensorView.
+ */
+ TensorView& operator=(const TensorView& tensor) = default;
+ /*!
+ * \brief Move assignment operator.
+ * \param tensor The input TensorView.
+ * \return The created TensorView.
+ */
+ TensorView& operator=(TensorView&& tensor) = default;
+ /*!
+ * \brief Assignment operator from a Tensor.
+ * \param tensor The input Tensor.
+ * \return The created TensorView.
+ */
+ TensorView& operator=(const Tensor& tensor) {
+ TVM_FFI_ICHECK(tensor.defined());
+ tensor_ = *tensor.operator->();
+ return *this;
+ }
+
+ // explicitly delete construction from an rvalue Tensor (a view would not retain ownership)
+ TensorView(Tensor&& tensor) = delete; // NOLINT(*)
+ // delete move assignment operator from owned tensor
+ TensorView& operator=(Tensor&& tensor) = delete;
+ /*!
+ * \brief Get the underlying DLTensor pointer.
+ * \return The underlying DLTensor pointer.
+ */
+ const DLTensor* operator->() const { return &tensor_; }
+
+ /*!
+ * \brief Get the shape of the Tensor.
+ * \return The shape of the Tensor.
+ */
+ ShapeView shape() const { return ShapeView(tensor_.shape, tensor_.ndim); }
+
+ /*!
+ * \brief Get the strides of the Tensor.
+ * \return The strides of the Tensor.
+ */
+ ShapeView strides() const {
+ TVM_FFI_ICHECK(tensor_.strides != nullptr);
+ return ShapeView(tensor_.strides, tensor_.ndim);
+ }
+
+ /*!
+ * \brief Get the data pointer of the Tensor.
+ * \return The data pointer of the Tensor.
+ */
+ void* data_ptr() const { return tensor_.data; }
+
+ /*!
+ * \brief Get the number of dimensions in the Tensor.
+ * \return The number of dimensions in the Tensor.
+ */
+ int32_t ndim() const { return tensor_.ndim; }
+
+ /*!
+ * \brief Get the number of elements in the Tensor.
+ * \return The number of elements in the Tensor.
+ */
+ int64_t numel() const { return this->shape().Product(); }
+
+ /*!
+ * \brief Get the data type of the Tensor.
+ * \return The data type of the Tensor.
+ */
+ DLDataType dtype() const { return tensor_.dtype; }
+
+ /*!
+ * \brief Check if the Tensor is contiguous.
+ * \return True if the Tensor is contiguous, false otherwise.
+ */
+ bool IsContiguous() const { return tvm::ffi::IsContiguous(tensor_); }
+
+ private:
+ DLTensor tensor_;
+};
+
+// TensorView type, allow implicit casting from DLTensor*
+// NOTE: we deliberately do not support MoveToAny and MoveFromAny since it
does not retain ownership
+template <>
+struct TypeTraits<TensorView> : public TypeTraitsBase {
+ static constexpr bool storage_enabled = false;
+ static constexpr int32_t field_static_type_index =
TypeIndex::kTVMFFIDLTensorPtr;
+
+ TVM_FFI_INLINE static void CopyToAnyView(const TensorView& src, TVMFFIAny*
result) {
+ result->type_index = TypeIndex::kTVMFFIDLTensorPtr;
+ result->zero_padding = 0;
+ TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
+ result->v_ptr = const_cast<DLTensor*>(src.operator->());
+ }
+
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ return src->type_index == TypeIndex::kTVMFFIDLTensorPtr;
+ }
+
+ TVM_FFI_INLINE static TensorView CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ return TensorView(static_cast<DLTensor*>(src->v_ptr));
+ }
+
+ TVM_FFI_INLINE static std::optional<TensorView> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) {
+ return TensorView(static_cast<DLTensor*>(src->v_ptr));
+ } else if (src->type_index == TypeIndex::kTVMFFITensor) {
+ return TensorView(TVMFFITensorGetDLTensorPtr(src->v_obj));
+ }
+ return std::nullopt;
+ }
+
+ TVM_FFI_INLINE static std::string TypeStr() { return
StaticTypeKey::kTVMFFIDLTensorPtr; }
+};
+
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 6eac9a4..796173f 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -82,6 +82,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIDataType = "DataType";
/*! \brief The type key for Device */
static constexpr const char* kTVMFFIDevice = "Device";
+ /*! \brief The type key for DLTensor* */
+ static constexpr const char* kTVMFFIDLTensorPtr = "DLTensor*";
/*! \brief The type key for const char* */
static constexpr const char* kTVMFFIRawStr = "const char*";
/*! \brief The type key for TVMFFIByteArray* */
diff --git a/pyproject.toml b/pyproject.toml
index ff52507..b655e88 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b12"
+version = "0.1.0b13"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 9a9ed83..112cd2b 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
"""TVM FFI Python package."""
# version
-__version__ = "0.1.0b12"
+__version__ = "0.1.0b13"
# order matters here so we need to skip isort here
# isort: skip_file
diff --git a/python/tvm_ffi/cpp/load_inline.py
b/python/tvm_ffi/cpp/load_inline.py
index 3d1d5b5..623ff88 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -453,7 +453,7 @@ def build_inline( # noqa: PLR0915, PLR0912
# define the cpp source code
cpp_source = '''
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -661,7 +661,7 @@ def load_inline(
# define the cpp source code
cpp_source = '''
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/rust/tvm-ffi/scripts/generate_example_lib.py
b/rust/tvm-ffi/scripts/generate_example_lib.py
index be38f17..43e822b 100644
--- a/rust/tvm-ffi/scripts/generate_example_lib.py
+++ b/rust/tvm-ffi/scripts/generate_example_lib.py
@@ -32,7 +32,7 @@ def main() -> None:
output_lib_path = tvm_ffi.cpp.build_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index bb9158a..f746a4a 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -163,4 +163,28 @@ TEST(Tensor, DLPackAllocError) {
tvm::ffi::Error);
}
+TEST(Tensor, TensorView) {
+ Tensor tensor = Empty({1, 2, 3}, DLDataType({kDLFloat, 32, 1}),
DLDevice({kDLCPU, 0}));
+ TensorView tensor_view = tensor;
+
+ EXPECT_EQ(tensor_view.shape().size(), 3);
+ EXPECT_EQ(tensor_view.shape()[0], 1);
+ EXPECT_EQ(tensor_view.shape()[1], 2);
+ EXPECT_EQ(tensor_view.shape()[2], 3);
+ EXPECT_EQ(tensor_view.dtype().code, kDLFloat);
+ EXPECT_EQ(tensor_view.dtype().bits, 32);
+ EXPECT_EQ(tensor_view.dtype().lanes, 1);
+
+ AnyView result = tensor_view;
+ EXPECT_EQ(result.type_index(), TypeIndex::kTVMFFIDLTensorPtr);
+ TensorView tensor_view2 = result.as<TensorView>().value();
+ EXPECT_EQ(tensor_view2.shape().size(), 3);
+ EXPECT_EQ(tensor_view2.shape()[0], 1);
+ EXPECT_EQ(tensor_view2.shape()[1], 2);
+ EXPECT_EQ(tensor_view2.shape()[2], 3);
+ EXPECT_EQ(tensor_view2.dtype().code, kDLFloat);
+ EXPECT_EQ(tensor_view2.dtype().bits, 32);
+ EXPECT_EQ(tensor_view2.dtype().lanes, 1);
+}
+
} // namespace
diff --git a/tests/python/test_build_inline.py
b/tests/python/test_build_inline.py
index d436271..5e3b191 100644
--- a/tests/python/test_build_inline.py
+++ b/tests/python/test_build_inline.py
@@ -24,7 +24,7 @@ def test_build_inline_cpp() -> None:
output_lib_path = tvm_ffi.cpp.build_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 229dc62..cd61c57 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -36,7 +36,7 @@ def test_load_inline_cpp() -> None:
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -62,7 +62,7 @@ def test_load_inline_cpp_with_docstrings() -> None:
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -89,7 +89,7 @@ def test_load_inline_cpp_multiple_sources() -> None:
name="hello",
cpp_sources=[
r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -103,7 +103,7 @@ def test_load_inline_cpp_multiple_sources() -> None:
}
""",
r"""
- void add_two_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_two_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -130,7 +130,7 @@ def test_load_inline_cpp_build_dir() -> None:
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -167,7 +167,7 @@ def test_load_inline_cuda() -> None:
}
}
- void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y, int64_t
raw_stream) {
+ void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y,
int64_t raw_stream) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -261,7 +261,7 @@ def test_load_inline_both() -> None:
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -274,7 +274,7 @@ def test_load_inline_both() -> None:
}
}
- void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y);
+ void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
""",
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
@@ -284,7 +284,7 @@ def test_load_inline_both() -> None:
}
}
- void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};