This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c3fc8f7 [Feature] Support vanilla C++ STL in FFI. (#228)
c3fc8f7 is described below
commit c3fc8f7f0e95a97beed342b6ddec4c3f6add0441
Author: DarkSharpness <[email protected]>
AuthorDate: Sun Nov 30 08:14:44 2025 +0800
[Feature] Support vanilla C++ STL in FFI. (#228)
Similar to pybind, we add a `stl.h` which supports `array`, `vector`,
`tuple`, `optional` and `variant` (as well as `map`, `unordered_map` and
`function`). After this file is included, users can use native C++
components directly, which should improve compatibility and reduce the
manual effort of converting between tvm::ffi components and C++ STL
components.
~~We also modify the `function_detail.h` a little, so that we support
all kinds of argument type (`T`, `const T`, `T&`, `const T&`, `T&&`,
`const T&&` have been tested) in C++ exported functions.~~
Example code:
```cpp
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/stl.h>
#include <tvm/ffi/function.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <numeric>
#include <optional>
#include <variant>
#include <vector>
namespace {
// std::optional, std::array, std::vector and std::tuple are all supported as
// argument and return types.
//
// Returns (true, per-row sums) when `arg` holds a value, and (false, {})
// otherwise.
auto sum_row(std::optional<std::vector<std::array<int, 2>>> arg)
    -> std::tuple<bool, std::vector<int>> {
  if (!arg) {
    return {false, {}};
  }
  auto result = std::vector<int>{};
  result.reserve(arg->size());
  for (const auto& row : *arg) {
    result.push_back(std::accumulate(row.begin(), row.end(), 0));
  }
  // Move the sums into the tuple instead of copying the whole vector.
  return {true, std::move(result)};
}
// Const-reference parameters are also supported, though not recommended: they
// bring no performance gain, because every argument must first be converted to
// a value and is then passed by reference.
//
// Returns the first index at which `a` and `b` differ, or the length of the
// shorter vector when one is a prefix of the other.
auto find_diff(const std::vector<int>& a, std::vector<int> b) -> std::size_t {
  const auto max_pos = std::min(a.size(), b.size());
  for (std::size_t i = 0; i < max_pos; ++i) {
    if (a[i] != b[i]) {
      return i;
    }
  }
  return max_pos;
}
// Dispatches on the active alternative of `arg`:
//   int n          -> the vector [0, 1, ..., n-1] (empty when n <= 0)
//   float x        -> x truncated to int
//   vector<int> v  -> v reversed (moved out of `arg`)
auto test_variant(std::variant<int, float, std::vector<int>>&& arg)
    -> std::variant<int, std::vector<int>> {
  if (std::holds_alternative<int>(arg)) {
    const int count = std::get<int>(arg);
    std::vector<int> result;
    // A negative count would wrap to a huge size_t and make reserve() throw
    // std::length_error, so only reserve for positive counts.
    if (count > 0) {
      result.reserve(static_cast<std::size_t>(count));
    }
    for (int i = 0; i < count; ++i) {
      result.push_back(i);
    }
    return result;
  } else if (std::holds_alternative<float>(arg)) {
    return static_cast<int>(std::get<float>(arg));
  } else {
    auto& value = std::get<std::vector<int>>(arg);
    std::reverse(value.begin(), value.end());
    // `value` refers into `arg`, so an explicit move is needed to avoid a copy.
    return std::move(value);
  }
}
// Export each function under its own name; the Python snippet below calls them
// as module.sum_row / module.find_diff / module.test_variant.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(sum_row, sum_row);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(find_diff, find_diff);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_variant, test_variant);
}  // namespace
```
Python part:
```python
from __future__ import annotations
from tvm_ffi.cpp import load_inline
from pathlib import Path
cur_path = Path(__file__).parent
with open(cur_path / "stl.cpp") as f:
cpp_source = f.read()
module = load_inline(
"test_stl",
cpp_sources = cpp_source,
)
print(module.sum_row([[1, 2], [3, 4]])) # Expected output: (True, [3, 7])
print(module.sum_row(None)) # Expected output: (False, [])
print(module.find_diff([1, 2, 3, 4], [1, 2, 4, 3])) # Expected output: 2
(index = 2)
print(module.test_variant(2)) # Expected output: [0, 1]
print(module.test_variant(3.1)) # Expected output: 3
print(module.test_variant([1, 2])) # Expected output: [2, 1]
```
---
include/tvm/ffi/container/map.h | 3 +
include/tvm/ffi/extra/stl.h | 649 +++++++++++++++++++++++++++++++++++++++
tests/python/cpp_src/test_stl.cc | 99 ++++++
tests/python/test_stl.py | 51 +++
4 files changed, 802 insertions(+)
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index b948fc8..526b345 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -246,6 +246,9 @@ class MapObj : public Object {
// Reference class
template <typename, typename, typename>
friend class Map;
+
+ template <typename, typename>
+ friend struct TypeTraits;
};
/*! \brief A specialization of small-sized hash map */
diff --git a/include/tvm/ffi/extra/stl.h b/include/tvm/ffi/extra/stl.h
new file mode 100644
index 0000000..523f77f
--- /dev/null
+++ b/include/tvm/ffi/extra/stl.h
@@ -0,0 +1,649 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/extra/stl.h
+ * \brief STL container support.
+ * \note This file is an extra extension of TVM FFI,
+ * which provides support for STL containers in C++ exported functions.
+ *
+ * Whenever possible, prefer using tvm/ffi/container/ implementations,
+ * such as `tvm::ffi::Array` and `tvm::ffi::Tuple`, over STL containers.
+ *
+ * Native FFI objects come with a stable data layout and can be directly accessed
+ * through compiled languages (Rust) and DSLs (via LLVM) with raw pointer access
+ * for better performance and compatibility.
+ */
+#ifndef TVM_FFI_EXTRA_STL_H_
+#define TVM_FFI_EXTRA_STL_H_
+
+#include <tvm/ffi/base_details.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <functional>
+#include <iterator>
+#include <map>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "tvm/ffi/function.h"
+
+namespace tvm {
+namespace ffi {
+namespace details {
+
+/*!
+ * \brief Internal exception signalling that an element of an STL container could not be
+ *        converted from Any.
+ *
+ * Thrown by STLTypeTrait::ConstructFromAny; the STL TryCastFromAnyView implementations in this
+ * header catch it and report failure as std::nullopt.
+ */
+struct STLTypeMismatch : public std::exception {
+  const char* what() const noexcept override { return "STL type mismatch"; }
+};
+
+/*!
+ * \brief Common base for the TypeTraits of STL types (storage_enabled is false for all of them).
+ */
+struct STLTypeTrait : public TypeTraitsBase {
+ public:
+  static constexpr bool storage_enabled = false;
+
+ protected:
+  /*!
+   * \brief Hand ownership of \p src over to \p result.
+   * \note We always copy STL types into an Object first, then move the ObjectPtr to Any.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static void MoveToAnyImpl(ObjectPtr<T>&& src, TVMFFIAny* result) {
+    TVMFFIObject* obj_ptr = ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(src));
+    result->type_index = obj_ptr->type_index;
+    result->zero_padding = 0;
+    TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
+    result->v_obj = obj_ptr;
+  }
+
+  /*!
+   * \brief Obtain an ObjectPtr<T> to the object stored in \p src (which keeps its own reference).
+   * \note We always construct STL types from an Object first, then copy from the ObjectPtr in
+   *       Any.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static ObjectPtr<T> CopyFromAnyImpl(const TVMFFIAny* src) {
+    return ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj);
+  }
+
+  /*!
+   * \brief Convert a single element stored in an Any into T.
+   * \throws STLTypeMismatch if the element cannot be converted to T.
+   * \note STL types are not natively movable from Any, so we always make a new copy.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static T ConstructFromAny(const Any& value) {
+    if constexpr (std::is_same_v<T, Any>) {
+      return value;
+    } else {
+      // Not all types support `CheckAnyStrict` (e.g. std::string),
+      // so we use `TryCastFromAnyView` instead (without any prior check).
+      auto opt = TypeTraits<T>::TryCastFromAnyView(AnyUnsafe::TVMFFIAnyPtrFromAny(value));
+      if (!opt.has_value()) {
+        throw STLTypeMismatch{};
+      }
+      return std::move(*opt);
+    }
+  }
+};
+
+/*! \brief Tag type whose TypeTraits hold the helpers shared by list-like STL types. */
+struct ListTemplate {};
+/*! \brief Tag type whose TypeTraits hold the helpers shared by map-like STL types. */
+struct MapTemplate {};
+
+} // namespace details
+
+/*!
+ * \brief Helpers shared by std::array, std::vector and std::tuple, which are all represented as
+ *        an FFI Array on the Any side.
+ */
+template <>
+struct TypeTraits<details::ListTemplate> : public details::STLTypeTrait {
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
+
+ private:
+  /*! \brief Build an ArrayObj whose i-th element is an Any made from std::get<Is>(src). */
+  template <std::size_t... Is, typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTupleImpl(std::index_sequence<Is...>,
+                                                            Tuple&& src) {
+    auto array = ArrayObj::Empty(static_cast<std::int64_t>(sizeof...(Is)));
+    auto dst = array->MutableBegin();
+    // Increase size after each placement new to ensure exception safety (assumes the array only
+    // destroys its first size_ elements -- confirm against ArrayObj). Forwarding `src` once per
+    // index is fine: std::get on an rvalue tuple only moves the selected element.
+    ((::new (dst++) Any(std::get<Is>(std::forward<Tuple>(src))), array->size_++), ...);
+    return array;
+  }
+
+  /*! \brief Build an ArrayObj from \p size elements read through iterator \p src. */
+  template <typename Iter>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArrayImpl(Iter src, std::size_t size) {
+    auto array = ArrayObj::Empty(static_cast<std::int64_t>(size));
+    auto dst = array->MutableBegin();
+    // increase size after each new to ensure exception safety
+    for (std::size_t i = 0; i < size; ++i) {
+      ::new (dst++) Any(*(src++));
+      array->size_++;
+    }
+    return array;
+  }
+
+ protected:
+  /*! \brief Copy every element of a tuple-like value into a new ArrayObj. */
+  template <typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTuple(const Tuple& src) {
+    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<Tuple>>{}, src);
+  }
+
+  /*!
+   * \brief Move every element of a tuple-like rvalue into a new ArrayObj.
+   * \note Rejects lvalues at compile time so an lvalue is never moved from silently.
+   */
+  template <typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToTuple(Tuple&& src) {
+    static_assert(!std::is_lvalue_reference_v<Tuple>,
+                  "MoveToTuple requires an rvalue; use CopyToTuple for lvalues.");
+    using TupleType = std::remove_cv_t<std::remove_reference_t<Tuple>>;
+    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<TupleType>>{},
+                           std::forward<Tuple>(src));
+  }
+
+  /*! \brief Copy every element of a sized range into a new ArrayObj. */
+  template <typename Range>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArray(const Range& src) {
+    return CopyToArrayImpl(std::begin(src), std::size(src));
+  }
+
+  /*!
+   * \brief Move every element of a sized rvalue range into a new ArrayObj.
+   * \note Rejects lvalues at compile time so an lvalue is never moved from silently.
+   */
+  template <typename Range>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToArray(Range&& src) {
+    static_assert(!std::is_lvalue_reference_v<Range>,
+                  "MoveToArray requires an rvalue; use CopyToArray for lvalues.");
+    return CopyToArrayImpl(std::make_move_iterator(std::begin(src)), std::size(src));
+  }
+};
+
+/*!
+ * \brief Helpers shared by std::map and std::unordered_map, which are both represented as an
+ *        FFI Map on the Any side.
+ */
+template <>
+struct TypeTraits<details::MapTemplate> : public details::STLTypeTrait {
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap;
+
+ protected:
+  /*! \brief Copy all (key, value) entries of \p src into a newly created map object. */
+  template <typename MapType>
+  TVM_FFI_INLINE static ObjectPtr<Object> CopyToMap(const MapType& src) {
+    auto first = std::begin(src);
+    auto last = std::end(src);
+    return MapObj::CreateFromRange(first, last);
+  }
+
+  /*! \brief Same as CopyToMap, but reads the entries of \p src through move iterators. */
+  template <typename MapType>
+  TVM_FFI_INLINE static ObjectPtr<Object> MoveToMap(MapType&& src) {
+    auto first = std::make_move_iterator(std::begin(src));
+    auto last = std::make_move_iterator(std::end(src));
+    return MapObj::CreateFromRange(first, last);
+  }
+
+  /*!
+   * \brief Rebuild an STL map of type MapType from the MapObj held by \p src.
+   * \tparam CanReserve whether MapType provides reserve() (true for std::unordered_map).
+   * \throws details::STLTypeMismatch if any key or value fails to convert.
+   */
+  template <typename MapType, bool CanReserve>
+  TVM_FFI_INLINE static MapType ConstructMap(const TVMFFIAny* src) {
+    using Key = typename MapType::key_type;
+    using Mapped = typename MapType::mapped_type;
+    auto source = CopyFromAnyImpl<MapObj>(src);
+    auto out = MapType{};
+    if constexpr (CanReserve) {
+      out.reserve(source->size());
+    }
+    for (const auto& [k, v] : *source) {
+      // try_emplace keeps the first entry if two FFI keys convert to the same STL key.
+      out.try_emplace(ConstructFromAny<Key>(k), ConstructFromAny<Mapped>(v));
+    }
+    return out;
+  }
+};
+
+/*! \brief TypeTraits for std::array<T, Nm>: converted to/from an FFI Array of exactly Nm elements. */
+template <typename T, std::size_t Nm>
+struct TypeTraits<std::array<T, Nm>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::array<T, Nm>;
+
+  /*! \brief Cheap pre-check: \p src holds an FFI Array whose length is Nm. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
+    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
+    // size_ is created from std::int64_t counts; compare as std::size_t to avoid a
+    // signed/unsigned comparison with Nm.
+    return static_cast<std::size_t>(n.size_) == Nm;
+  }
+
+ public:
+  static_assert(Nm > 0, "Zero-length std::array is not supported.");
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(CopyToArray(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(MoveToArray(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert \p src to std::array<T, Nm>.
+   * \return std::nullopt if \p src is not an Array of length Nm or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) return std::nullopt;
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      auto begin = array->MutableBegin();
+      Self result;  // default-initialized, then each element is assigned below
+      for (std::size_t i = 0; i < Nm; ++i) {
+        result[i] = ConstructFromAny<T>(begin[i]);
+      }
+      return result;
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    return "std::array<" + details::Type2Str<T>::v() + ", " + std::to_string(Nm) + ">";
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    return R"({"type":"std::array","args":[)" + details::TypeSchema<T>::v() + "," +
+           std::to_string(Nm) + "]}";
+  }
+};
+
+/*! \brief TypeTraits for std::vector<T>: converted to/from an FFI Array of any length. */
+template <typename T>
+struct TypeTraits<std::vector<T>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::vector<T>;
+
+  /*! \brief Cheap pre-check: \p src holds an FFI Array. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIArray;
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(CopyToArray(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(MoveToArray(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert \p src to std::vector<T>.
+   * \return std::nullopt if \p src is not an Array or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) return std::nullopt;
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      auto begin = array->MutableBegin();
+      // Convert the element count once so the loop compares std::size_t with std::size_t
+      // (size_ is created from std::int64_t counts).
+      const auto length = static_cast<std::size_t>(array->size_);
+      auto result = Self{};
+      result.reserve(length);
+      for (std::size_t i = 0; i < length; ++i) {
+        result.emplace_back(ConstructFromAny<T>(begin[i]));
+      }
+      return result;
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    return "std::vector<" + details::Type2Str<T>::v() + ">";
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    return R"({"type":"std::vector","args":[)" + details::TypeSchema<T>::v() + "]}";
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::optional<T>: None maps to std::nullopt, anything else is handled by
+ *        TypeTraits<T>.
+ */
+template <typename T>
+struct TypeTraits<std::optional<T>> : public TypeTraitsBase {
+ public:
+  using Self = std::optional<T>;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    if (src.has_value()) {
+      TypeTraits<T>::CopyToAnyView(*src, result);
+    } else {
+      TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result);
+    }
+  }
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    if (src.has_value()) {
+      TypeTraits<T>::MoveToAny(std::move(*src), result);
+    } else {
+      TypeTraits<std::nullptr_t>::MoveToAny(nullptr, result);
+    }
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return true;
+    return TypeTraits<T>::CheckAnyStrict(src);
+  }
+
+  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    return TypeTraits<T>::CopyFromAnyViewAfterCheck(src);
+  }
+
+  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    return TypeTraits<T>::MoveFromAnyAfterCheck(src);
+  }
+
+  /*!
+   * \brief Try to convert \p src to std::optional<T>.
+   * \return An engaged result holding std::nullopt when \p src is None, an engaged result holding
+   *         the converted value on success, and an empty (outer) optional on failure.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    if (std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(src)) {
+      // std::optional<T> is exactly Self, so wrap it as-is.
+      return std::optional<Self>{std::in_place, std::move(opt)};
+    }
+    return std::nullopt;  // failed to cast, indicate failure
+  }
+
+  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+    return TypeTraits<T>::GetMismatchTypeInfo(src);
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    // Use Type2Str like the other STL traits in this header.
+    return "std::optional<" + details::Type2Str<T>::v() + ">";
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    return R"({"type":"std::optional","args":[)" + details::TypeSchema<T>::v() + "]}";
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::variant<Args...>.
+ *
+ * Conversions from Any try the alternatives in declaration order and pick the first match;
+ * conversions to Any forward to the TypeTraits of the active alternative.
+ */
+template <typename... Args>
+struct TypeTraits<std::variant<Args...>> : public TypeTraitsBase {
+ private:
+  using Self = std::variant<Args...>;
+  static constexpr std::size_t Nm = sizeof...(Args);
+
+  /*! \brief Copy \p src into the first alternative (index >= Is) whose CheckAnyStrict passes. */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static Self CopyUnsafeAux(const TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
+      throw;  // unreachable
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
+        return Self{std::in_place_index<Is>, TypeTraits<ElemType>::CopyFromAnyViewAfterCheck(src)};
+      } else {
+        return CopyUnsafeAux<Is + 1>(src);
+      }
+    }
+  }
+
+  /*!
+   * \brief Move \p src into the first alternative (index >= Is) whose CheckAnyStrict passes.
+   * \note \p src must be a mutable pointer: MoveFromAnyAfterCheck takes TVMFFIAny*, so a
+   *       const pointer here would not compile once instantiated.
+   */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static Self MoveUnsafeAux(TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
+      throw;  // unreachable
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
+        return Self{std::in_place_index<Is>, TypeTraits<ElemType>::MoveFromAnyAfterCheck(src)};
+      } else {
+        return MoveUnsafeAux<Is + 1>(src);
+      }
+    }
+  }
+
+  /*! \brief Return the first alternative (index >= Is) that \p src can be cast to, if any. */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static std::optional<Self> TryCastAux(const TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      return std::nullopt;
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (auto opt = TypeTraits<ElemType>::TryCastFromAnyView(src)) {
+        return Self{std::in_place_index<Is>, std::move(*opt)};
+      } else {
+        return TryCastAux<Is + 1>(src);
+      }
+    }
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return std::visit(
+        [&](const auto& value) {
+          using ValueType = std::decay_t<decltype(value)>;
+          TypeTraits<ValueType>::CopyToAnyView(value, result);
+        },
+        src);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return std::visit(
+        [&](auto&& value) {
+          using ValueType = std::decay_t<decltype(value)>;
+          // `value` binds to the active alternative of the moved-from variant.
+          TypeTraits<ValueType>::MoveToAny(std::move(value), result);
+        },
+        std::move(src));
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    return (TypeTraits<Args>::CheckAnyStrict(src) || ...);
+  }
+
+  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    // find the first possible type to copy
+    return CopyUnsafeAux(src);
+  }
+
+  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    // find the first possible type to move
+    return MoveUnsafeAux(src);
+  }
+
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    // try to find the first possible type to copy
+    return TryCastAux(src);
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream os;
+    os << "std::variant<";
+    const char* sep = "";
+    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
+    os << ">";
+    return std::move(os).str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream os;
+    os << R"({"type":"std::variant","args":[)";
+    const char* sep = "";
+    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
+    os << "]}";
+    return std::move(os).str();
+  }
+};
+
+/*! \brief TypeTraits for std::tuple<Args...>: converted to/from an FFI Array of exactly Nm items. */
+template <typename... Args>
+struct TypeTraits<std::tuple<Args...>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::tuple<Args...>;
+  static constexpr std::size_t Nm = sizeof...(Args);
+  static_assert(Nm > 0, "Zero-length std::tuple is not supported.");
+
+  /*! \brief Cheap pre-check: \p src holds an FFI Array whose length is Nm. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
+    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
+    // Compare as std::size_t to avoid a signed/unsigned comparison with Nm.
+    return static_cast<std::size_t>(n.size_) == Nm;
+  }
+
+  /*!
+   * \brief Strictly check element Is of \p n against the Is-th tuple element type, for every Is.
+   * \note Requires every element type's TypeTraits to provide CheckAnyStrict.
+   */
+  template <std::size_t... Is>
+  TVM_FFI_INLINE static bool CheckSubTypeAux(std::index_sequence<Is...>, const ArrayObj& n) {
+    return (TypeTraits<std::tuple_element_t<Is, Self>>::CheckAnyStrict(
+                AnyUnsafe::TVMFFIAnyPtrFromAny(n[Is])) &&
+            ...);
+  }
+
+  /*!
+   * \brief Convert every element of \p n to the matching tuple element type.
+   * \throws details::STLTypeMismatch if any element fails to convert.
+   */
+  template <std::size_t... Is>
+  TVM_FFI_INLINE static Self ConstructTupleAux(std::index_sequence<Is...>, const ArrayObj& n) {
+    // Braced initialization evaluates the conversions left to right.
+    return Self{ConstructFromAny<std::tuple_element_t<Is, Self>>(n[Is])...};
+  }
+
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(CopyToTuple(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(MoveToTuple(std::move(src)), result);
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    // check array type and static length first
+    if (!CheckAnyFast(src)) return false;
+    // then check element type
+    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
+    return CheckSubTypeAux(std::make_index_sequence<Nm>{}, n);
+  }
+
+  /*!
+   * \brief Convert \p src to std::tuple<Args...>.
+   * \return std::nullopt if \p src is not an Array of length Nm or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) return std::nullopt;
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      return ConstructTupleAux(std::make_index_sequence<Nm>{}, *array);
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream os;
+    os << "std::tuple<";
+    const char* sep = "";
+    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
+    os << ">";
+    return std::move(os).str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream os;
+    os << R"({"type":"std::tuple","args":[)";
+    const char* sep = "";
+    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
+    os << "]}";
+    return std::move(os).str();
+  }
+};
+
+/*! \brief TypeTraits for std::map<K, V>: converted to/from an FFI Map. */
+template <typename K, typename V>
+struct TypeTraits<std::map<K, V>> : public TypeTraits<details::MapTemplate> {
+ private:
+  using Self = std::map<K, V>;
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    MoveToAnyImpl(CopyToMap(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    MoveToAnyImpl(MoveToMap(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert \p src to std::map<K, V>.
+   * \return std::nullopt if \p src is not a Map or any key/value fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    // Only an FFI Map can be converted.
+    if (src->type_index != TypeIndex::kTVMFFIMap) {
+      return std::nullopt;
+    }
+    try {
+      return ConstructMap<Self, /*CanReserve=*/false>(src);
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string out = "std::map<";
+    out += details::Type2Str<K>::v();
+    out += ", ";
+    out += details::Type2Str<V>::v();
+    out += ">";
+    return out;
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string out = R"({"type":"std::map","args":[)";
+    out += details::TypeSchema<K>::v();
+    out += ",";
+    out += details::TypeSchema<V>::v();
+    out += "]}";
+    return out;
+  }
+};
+
+/*! \brief TypeTraits for std::unordered_map<K, V>: converted to/from an FFI Map. */
+template <typename K, typename V>
+struct TypeTraits<std::unordered_map<K, V>> : public TypeTraits<details::MapTemplate> {
+ private:
+  using Self = std::unordered_map<K, V>;
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    MoveToAnyImpl(CopyToMap(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    MoveToAnyImpl(MoveToMap(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert \p src to std::unordered_map<K, V> (pre-reserving buckets for every entry).
+   * \return std::nullopt if \p src is not a Map or any key/value fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    // Only an FFI Map can be converted.
+    if (src->type_index != TypeIndex::kTVMFFIMap) {
+      return std::nullopt;
+    }
+    try {
+      return ConstructMap<Self, /*CanReserve=*/true>(src);
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string out = "std::unordered_map<";
+    out += details::Type2Str<K>::v();
+    out += ", ";
+    out += details::Type2Str<V>::v();
+    out += ">";
+    return out;
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string out = R"({"type":"std::unordered_map","args":[)";
+    out += details::TypeSchema<K>::v();
+    out += ",";
+    out += details::TypeSchema<V>::v();
+    out += "]}";
+    return out;
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::function<Ret(Args...)>, implemented by round-tripping through
+ *        TypedFunction<Ret(Args...)>.
+ */
+template <typename Ret, typename... Args>
+struct TypeTraits<std::function<Ret(Args...)>> : TypeTraitsBase {
+ private:
+  using Self = std::function<Ret(Args...)>;
+  using Function = TypedFunction<Ret(Args...)>;
+  using ProxyTrait = TypeTraits<Function>;
+
+  /*! \brief Wrap an FFI TypedFunction in a std::function with the same signature. */
+  TVM_FFI_INLINE static Self GetFunc(Function&& f) {
+    return [callee = std::move(f)](Args... args) -> Ret {
+      return callee(std::forward<Args>(args)...);
+    };
+  }
+
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction;
+  static constexpr bool storage_enabled = false;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    ProxyTrait::MoveToAny(Function{src}, result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    ProxyTrait::MoveToAny(Function{std::move(src)}, result);
+  }
+
+  /*!
+   * \brief Convert \p src to std::function<Ret(Args...)>.
+   * \return std::nullopt when TypeTraits<TypedFunction<Ret(Args...)>> rejects \p src.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    auto proxy = ProxyTrait::TryCastFromAnyView(src);
+    if (!proxy.has_value()) {
+      return std::nullopt;
+    }
+    return GetFunc(std::move(*proxy));
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream os;
+    os << "std::function<" << details::Type2Str<Ret>::v() << "(";
+    bool first = true;
+    ((os << (first ? "" : ", ") << details::Type2Str<Args>::v(), first = false), ...);
+    os << ")>";
+    return os.str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream os;
+    os << R"({"type":"std::function","args":[)" << details::TypeSchema<Ret>::v() << ",[";
+    bool first = true;
+    ((os << (first ? "" : ", ") << details::TypeSchema<Args>::v(), first = false), ...);
+    os << "]]}";
+    return os.str();
+  }
+};
+
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_EXTRA_STL_H_
diff --git a/tests/python/cpp_src/test_stl.cc b/tests/python/cpp_src/test_stl.cc
new file mode 100644
index 0000000..d7ed421
--- /dev/null
+++ b/tests/python/cpp_src/test_stl.cc
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ffi/extra/stl.h>
+#include <tvm/ffi/function.h>
+
+#include <array>
+#include <functional>
+#include <map>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+#include <variant>
+#include <vector>
+
+namespace {
+
+// Swap the two slots; exercises std::tuple as both argument and return type.
+auto test_tuple(std::tuple<int, float> arg) -> std::tuple<float, int> {
+  const auto& [first, second] = arg;
+  return std::make_tuple(second, first);
+}
+
+// Sum every 2-element row, or return std::nullopt when no input is given.
+auto test_vector(std::optional<std::vector<std::array<int, 2>>> arg)
+    -> std::optional<std::vector<int>> {
+  if (!arg.has_value()) {
+    return std::nullopt;
+  }
+  std::vector<int> sums;
+  sums.reserve(arg->size());
+  for (const std::array<int, 2>& row : arg.value()) {
+    sums.emplace_back(std::accumulate(row.begin(), row.end(), 0));
+  }
+  return sums;
+}
+
+// Name the active alternative ("int" / "float"); a list is reported element-wise.
+auto test_variant(std::variant<int, float, std::vector<std::variant<int, float>>> arg)
+    -> std::variant<std::string, std::vector<std::string>> {
+  using Item = std::variant<int, float>;
+  if (std::holds_alternative<int>(arg)) return "int";
+  if (std::holds_alternative<float>(arg)) return "float";
+  std::vector<std::string> names;
+  for (const Item& item : std::get<std::vector<Item>>(arg)) {
+    names.emplace_back(std::holds_alternative<int>(item) ? "int" : "float");
+  }
+  return names;
+}
+
+// Invert the map (value -> key); when several keys share a value, the last one visited wins.
+auto test_map(const std::map<std::string, int>& map) -> std::map<int, std::string> {
+  std::map<int, std::string> inverted;
+  for (const auto& entry : map) {
+    inverted.insert_or_assign(entry.second, entry.first);
+  }
+  return inverted;
+}
+
+// Same as test_map, but for std::unordered_map.
+auto test_map_2(const std::unordered_map<std::string, int>& map)
+    -> std::unordered_map<int, std::string> {
+  std::unordered_map<int, std::string> inverted;
+  for (const auto& entry : map) {
+    inverted.insert_or_assign(entry.second, entry.first);
+  }
+  return inverted;
+}
+
+// Return a callable that yields f() + 1.
+auto test_function(std::function<int(void)> f) -> std::function<int(void)> {
+  return [inner = std::move(f)]() -> int { return inner() + 1; };
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_tuple, test_tuple);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_vector, test_vector);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_variant, test_variant);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_map, test_map);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_map_2, test_map_2);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_function, test_function);
+
+} // namespace
diff --git a/tests/python/test_stl.py b/tests/python/test_stl.py
new file mode 100644
index 0000000..4d87ff2
--- /dev/null
+++ b/tests/python/test_stl.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pathlib
+
+import pytest
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def test_stl() -> None:
+ cpp_path = pathlib.Path(__file__).parent.resolve() / "cpp_src" /
"test_stl.cc"
+ output_lib_path = tvm_ffi.cpp.build(
+ name="test_stl",
+ cpp_files=[str(cpp_path)],
+ )
+
+ mod: Module = tvm_ffi.load_module(output_lib_path)
+
+ assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
+ assert mod.test_vector(None) == None
+ assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
+ assert mod.test_variant(1) == "int"
+ assert mod.test_variant(1.0) == "float"
+ assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
+ assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+ assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+ assert mod.test_function(lambda: 0)() == 1
+ assert mod.test_function(lambda: 10)() == 11
+
+ with pytest.raises(TypeError):
+ mod.test_tuple([1.5, 2.5])
+ with pytest.raises(TypeError):
+ mod.test_function(lambda: 0)(100)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])