This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new c3fc8f7  [Feature] Support vanilla C++ STL in FFI. (#228)
c3fc8f7 is described below

commit c3fc8f7f0e95a97beed342b6ddec4c3f6add0441
Author: DarkSharpness <[email protected]>
AuthorDate: Sun Nov 30 08:14:44 2025 +0800

    [Feature] Support vanilla C++ STL in FFI. (#228)
    
    Similar to pybind, we add a `stl.h` which supports `array`, `vector`,
    `tuple`, `optional`, `variant`, `map`, `unordered_map` and `function`.
    After this file is included, users can use native C++ components
    directly, which should improve compatibility and reduce the manual effort
    of converting between tvm::ffi components and C++ STL components.
    
    ~~We also modify the `function_detail.h` a little, so that we support
    all kinds of argument type (`T`, `const T`, `T&`, `const T&`, `T&&`,
    `const T&&` have been tested) in C++ exported functions.~~
    
    Example code:
    
    ```cpp
    #include <tvm/ffi/container/array.h>
    #include <tvm/ffi/container/tensor.h>
    #include <tvm/ffi/dtype.h>
    #include <tvm/ffi/error.h>
    #include <tvm/ffi/extra/c_env_api.h>
    #include <tvm/ffi/extra/stl.h>
    #include <tvm/ffi/function.h>
    
    #include <algorithm>
    #include <array>
    #include <cstddef>
    #include <numeric>
    #include <optional>
    #include <variant>
    #include <vector>
    
    namespace {
    
    // optional, array, vector and tuple are supported
    auto sum_row(std::optional<std::vector<std::array<int, 2>>> arg)
        -> std::tuple<bool, std::vector<int>> {
      if (arg) {
        auto result = std::vector<int>{};
        result.reserve(arg->size());
        for (const auto& row : *arg) {
          result.push_back(std::accumulate(row.begin(), row.end(), 0));
        }
        return {true, result};
      } else {
        return {false, {}};
      }
    }
    
    // const reference is also supported, though it is not recommended and
    // brings no performance gain: every argument is first converted to a
    // value, which is then passed by reference
    auto find_diff(const std::vector<int>& a, std::vector<int> b) -> std::size_t {
      const auto max_pos = std::min(a.size(), b.size());
      for (std::size_t i = 0; i < max_pos; ++i) {
        if (a[i] != b[i]) {
          return i;
        }
      }
      return max_pos;
    }
    
    auto test_variant(std::variant<int, float, std::vector<int>>&& arg)
        -> std::variant<int, std::vector<int>> {
      if (std::holds_alternative<int>(arg)) {
        std::vector<int> result;
        auto& value = std::get<int>(arg);
        result.reserve(value);
        for (int i = 0; i < value; ++i) {
          result.push_back(i);
        }
        return result;
      } else if (std::holds_alternative<float>(arg)) {
        return static_cast<int>(std::get<float>(arg));
      } else {
        auto& value = std::get<std::vector<int>>(arg);
        std::reverse(value.begin(), value.end());
        return std::move(value);
      }
    }
    
    TVM_FFI_DLL_EXPORT_TYPED_FUNC(sum_row, sum_row);
    TVM_FFI_DLL_EXPORT_TYPED_FUNC(find_diff, find_diff);
    TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_variant, test_variant);
    
    }  // namespace
    
    ```
    
    Python part:
    ```python
    from __future__ import annotations
    
    from tvm_ffi.cpp import load_inline
    from pathlib import Path
    
    cur_path = Path(__file__).parent
    
    with open(cur_path / "stl.cpp") as f:
        cpp_source = f.read()
    
    module = load_inline(
        "test_stl",
        cpp_sources = cpp_source,
    )
    
    print(module.sum_row([[1, 2], [3, 4]]))  # Expected output: (True, [3, 7])
    print(module.sum_row(None))  # Expected output: (False, [])
    print(module.find_diff([1, 2, 3, 4], [1, 2, 4, 3]))  # Expected output: 2 (first differing index)
    print(module.test_variant(2))  # Expected output: [0, 1]
    print(module.test_variant(3.1))  # Expected output: 3
    print(module.test_variant([1, 2]))  # Expected output: [2, 1]
    ```
---
 include/tvm/ffi/container/map.h  |   3 +
 include/tvm/ffi/extra/stl.h      | 649 +++++++++++++++++++++++++++++++++++++++
 tests/python/cpp_src/test_stl.cc |  99 ++++++
 tests/python/test_stl.py         |  51 +++
 4 files changed, 802 insertions(+)

diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index b948fc8..526b345 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -246,6 +246,9 @@ class MapObj : public Object {
   // Reference class
   template <typename, typename, typename>
   friend class Map;
+
+  template <typename, typename>
+  friend struct TypeTraits;
 };
 
 /*! \brief A specialization of small-sized hash map */
diff --git a/include/tvm/ffi/extra/stl.h b/include/tvm/ffi/extra/stl.h
new file mode 100644
index 0000000..523f77f
--- /dev/null
+++ b/include/tvm/ffi/extra/stl.h
@@ -0,0 +1,649 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/extra/stl.h
+ * \brief STL container support.
+ * \note This file is an extra extension of TVM FFI,
+ * which provides support for STL containers in C++ exported functions.
+ * STL values are never stored in place: converting to Any copies (or moves)
+ * the value into a freshly allocated ffi object, and converting back copies
+ * the data out of that object.
+ *
+ * Whenever possible, prefer using tvm/ffi/container/ implementations,
+ * such as `tvm::ffi::Array` and `tvm::ffi::Tuple`, over STL containers.
+ *
+ * Native ffi objects come with a stable data layout and can be accessed directly
+ * through compiled languages (Rust) and DSLs (via LLVM) with raw pointer access
+ * for better performance and compatibility.
+ */
+#ifndef TVM_FFI_EXTRA_STL_H_
+#define TVM_FFI_EXTRA_STL_H_
+
+#include <tvm/ffi/base_details.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <functional>
+#include <iterator>
+#include <map>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+namespace details {
+
+/*! \brief Internal signal that an element inside an STL container failed to convert. */
+struct STLTypeMismatch : public std::exception {
+  const char* what() const noexcept override {
+    return "STL type mismatch";
+  }
+};
+
+/*!
+ * \brief Helpers shared by the TypeTraits of STL types.
+ *
+ * STL values are never stored in place inside Any (storage_enabled is false):
+ * they are always copied/moved into a new ffi object first, and converted back
+ * by copying out of the object referenced by the Any.
+ */
+struct STLTypeTrait : public TypeTraitsBase {
+ public:
+  static constexpr bool storage_enabled = false;
+
+ protected:
+  /*!
+   * \brief Hand ownership of a freshly built object over to \p result.
+   * \param src Owning pointer to the new object; consumed by this call.
+   * \param result The Any slot to fill.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static void MoveToAnyImpl(ObjectPtr<T>&& src, TVMFFIAny* result) {
+    TVMFFIObject* raw = ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(src));
+    // Keep this exact order: header fields, padding clear, then the payload pointer.
+    result->type_index = raw->type_index;
+    result->zero_padding = 0;
+    TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
+    result->v_obj = raw;
+  }
+
+  /*!
+   * \brief Wrap the object referenced by \p src (not owned by the caller) in an ObjectPtr<T>.
+   * \note No type check happens here; callers must have verified the type index.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static ObjectPtr<T> CopyFromAnyImpl(const TVMFFIAny* src) {
+    return ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj);
+  }
+
+  /*!
+   * \brief Convert one element stored inside an ffi container into a fresh T.
+   * \throws STLTypeMismatch if the element cannot be cast to T.
+   */
+  template <typename T>
+  TVM_FFI_INLINE static T ConstructFromAny(const Any& value) {
+    if constexpr (std::is_same_v<T, Any>) {
+      return value;
+    } else {
+      // Not every type implements CheckAnyStrict (std::string, for example),
+      // so rely on TryCast alone and turn a failed cast into an exception.
+      auto converted =
+          TypeTraits<T>::TryCastFromAnyView(AnyUnsafe::TVMFFIAnyPtrFromAny(value));
+      if (converted.has_value()) {
+        return std::move(*converted);
+      }
+      throw STLTypeMismatch{};
+    }
+  }
+};
+
+/*! \brief Tag whose TypeTraits specialization holds the Array-backed helpers. */
+struct ListTemplate {};
+/*! \brief Tag whose TypeTraits specialization holds the Map-backed helpers. */
+struct MapTemplate {};
+
+}  // namespace details
+
+/*!
+ * \brief Helpers shared by the STL types that are stored as an ffi Array
+ *        (std::array, std::vector and std::tuple).
+ */
+template <>
+struct TypeTraits<details::ListTemplate> : public details::STLTypeTrait {
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
+
+ private:
+  /*!
+   * \brief Build an ArrayObj holding one Any per tuple element.
+   * \note Each std::get<Is> reads a distinct element, so forwarding \p src once
+   *       per index never touches an element that was already moved from.
+   */
+  template <std::size_t... Is, typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTupleImpl(std::index_sequence<Is...>,
+                                                            Tuple&& src) {
+    auto array = ArrayObj::Empty(static_cast<std::int64_t>(sizeof...(Is)));
+    auto dst = array->MutableBegin();
+    // increase size after each new to ensure exception safety:
+    // size_ always counts exactly the elements constructed so far
+    ((::new (dst++) Any(std::get<Is>(std::forward<Tuple>(src))), array->size_++), ...);
+    return array;
+  }
+
+  /*! \brief Build an ArrayObj from \p size elements read through iterator \p src. */
+  template <typename Iter>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArrayImpl(Iter src, std::size_t size) {
+    auto array = ArrayObj::Empty(static_cast<std::int64_t>(size));
+    auto dst = array->MutableBegin();
+    // increase size after each new to ensure exception safety
+    for (std::size_t i = 0; i < size; ++i) {
+      ::new (dst++) Any(*(src++));
+      array->size_++;
+    }
+    return array;
+  }
+
+ protected:
+  /*! \brief Copy every element of a tuple-like value into a new ArrayObj. */
+  template <typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTuple(const Tuple& src) {
+    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<Tuple>>{}, src);
+  }
+
+  /*! \brief Move every element of an rvalue tuple-like value into a new ArrayObj. */
+  template <typename Tuple>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToTuple(Tuple&& src) {
+    // Tuple&& is a forwarding reference: an lvalue would deduce Tuple as a
+    // reference and fail deep inside std::tuple_size_v; give a clear error instead.
+    static_assert(!std::is_lvalue_reference_v<Tuple>,
+                  "MoveToTuple requires an rvalue; use CopyToTuple for lvalues.");
+    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<Tuple>>{},
+                           std::forward<Tuple>(src));
+  }
+
+  /*! \brief Copy every element of a sized range into a new ArrayObj. */
+  template <typename Range>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArray(const Range& src) {
+    return CopyToArrayImpl(std::begin(src), std::size(src));
+  }
+
+  /*! \brief Move every element of an rvalue sized range into a new ArrayObj. */
+  template <typename Range>
+  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToArray(Range&& src) {
+    // Range&& is a forwarding reference: without this guard an lvalue container
+    // would have its elements silently moved out through the move iterator.
+    static_assert(!std::is_lvalue_reference_v<Range>,
+                  "MoveToArray requires an rvalue; use CopyToArray for lvalues.");
+    return CopyToArrayImpl(std::make_move_iterator(std::begin(src)), std::size(src));
+  }
+};
+
+/*!
+ * \brief Helpers shared by the STL types that are stored as an ffi Map
+ *        (std::map and std::unordered_map).
+ */
+template <>
+struct TypeTraits<details::MapTemplate> : public details::STLTypeTrait {
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap;
+
+ protected:
+  /*! \brief Copy every key/value pair of \p src into a new MapObj. */
+  template <typename MapType>
+  TVM_FFI_INLINE static ObjectPtr<Object> CopyToMap(const MapType& src) {
+    return MapObj::CreateFromRange(std::begin(src), std::end(src));
+  }
+
+  /*! \brief Move every key/value pair of an rvalue \p src into a new MapObj. */
+  template <typename MapType>
+  TVM_FFI_INLINE static ObjectPtr<Object> MoveToMap(MapType&& src) {
+    // MapType&& is a forwarding reference: without this guard an lvalue map would
+    // have its values silently moved out through the move iterators below.
+    static_assert(!std::is_lvalue_reference_v<MapType>,
+                  "MoveToMap requires an rvalue; use CopyToMap for lvalues.");
+    return MapObj::CreateFromRange(std::make_move_iterator(std::begin(src)),
+                                   std::make_move_iterator(std::end(src)));
+  }
+
+  /*!
+   * \brief Build an STL map from the MapObj referenced by \p src.
+   * \tparam CanReserve Whether MapType provides reserve().
+   * \throws details::STLTypeMismatch if any key or value fails to convert.
+   * \note No type check happens here; callers must have verified \p src holds a Map.
+   */
+  template <typename MapType, bool CanReserve>
+  TVM_FFI_INLINE static MapType ConstructMap(const TVMFFIAny* src) {
+    using KeyType = typename MapType::key_type;
+    using ValueType = typename MapType::mapped_type;
+    auto result = MapType{};
+    auto map_obj = CopyFromAnyImpl<MapObj>(src);
+    if constexpr (CanReserve) {
+      result.reserve(map_obj->size());
+    }
+    for (const auto& [key, value] : *map_obj) {
+      result.try_emplace(ConstructFromAny<KeyType>(key), ConstructFromAny<ValueType>(value));
+    }
+    return result;
+  }
+};
+
+/*! \brief TypeTraits for std::array<T, Nm>, stored as an ffi Array of exactly Nm elements. */
+template <typename T, std::size_t Nm>
+struct TypeTraits<std::array<T, Nm>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::array<T, Nm>;
+
+  /*! \brief Accept only an ffi Array whose length equals Nm. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFIArray) {
+      const auto* arr = reinterpret_cast<const ArrayObj*>(src->v_obj);
+      return arr->size_ == Nm;
+    }
+    return false;
+  }
+
+ public:
+  static_assert(Nm > 0, "Zero-length std::array is not supported.");
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    MoveToAnyImpl(CopyToArray(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    MoveToAnyImpl(MoveToArray(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert an ffi Array of length Nm into a std::array.
+   * \return std::nullopt if the length differs or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) {
+      return std::nullopt;
+    }
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      auto elems = array->MutableBegin();
+      Self out;  // left default-initialized; every slot is assigned below
+      std::size_t pos = 0;
+      for (auto& slot : out) {
+        slot = ConstructFromAny<T>(elems[pos++]);
+      }
+      return out;
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string out = "std::array<";
+    out += details::Type2Str<T>::v();
+    out += ", ";
+    out += std::to_string(Nm);
+    out += ">";
+    return out;
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string out = R"({"type":"std::array","args":[)";
+    out += details::TypeSchema<T>::v();
+    out += ",";
+    out += std::to_string(Nm);
+    out += "]}";
+    return out;
+  }
+};
+
+/*! \brief TypeTraits for std::vector<T>, stored as an ffi Array of any length. */
+template <typename T>
+struct TypeTraits<std::vector<T>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::vector<T>;
+
+  /*! \brief Any ffi Array is a candidate; element types are checked during conversion. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    return src->type_index == TypeIndex::kTVMFFIArray;
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(CopyToArray(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(MoveToArray(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert an ffi Array into a std::vector.
+   * \return std::nullopt if \p src is not an Array or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) return std::nullopt;
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      auto begin = array->MutableBegin();
+      // size_ is a signed count (cf. the int64_t passed to ArrayObj::Empty); convert it
+      // once so the loop below does not mix signed and unsigned (-Wsign-compare).
+      const auto length = static_cast<std::size_t>(array->size_);
+      auto result = Self{};
+      result.reserve(length);
+      for (std::size_t i = 0; i < length; ++i) {
+        result.emplace_back(ConstructFromAny<T>(begin[i]));
+      }
+      return result;
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    return "std::vector<" + details::Type2Str<T>::v() + ">";
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    return R"({"type":"std::vector","args":[)" + details::TypeSchema<T>::v() + "]}";
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::optional<T>: None maps to std::nullopt, and any other
+ *        value is delegated to TypeTraits<T>.
+ */
+template <typename T>
+struct TypeTraits<std::optional<T>> : public TypeTraitsBase {
+ public:
+  using Self = std::optional<T>;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    if (src.has_value()) {
+      TypeTraits<T>::CopyToAnyView(*src, result);
+    } else {
+      TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result);
+    }
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    if (src.has_value()) {
+      TypeTraits<T>::MoveToAny(std::move(*src), result);
+    } else {
+      TypeTraits<std::nullptr_t>::MoveToAny(nullptr, result);
+    }
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return true;
+    return TypeTraits<T>::CheckAnyStrict(src);
+  }
+
+  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    return TypeTraits<T>::CopyFromAnyViewAfterCheck(src);
+  }
+
+  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    return TypeTraits<T>::MoveFromAnyAfterCheck(src);
+  }
+
+  /*!
+   * \brief Try to convert \p src.
+   * \return An engaged outer optional on success (holding std::nullopt when \p src
+   *         is None), or an empty outer optional if \p src cannot be cast to T.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
+    if (auto inner = TypeTraits<T>::TryCastFromAnyView(src)) {
+      return Self{std::move(*inner)};
+    }
+    return std::nullopt;
+  }
+
+  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+    return TypeTraits<T>::GetMismatchTypeInfo(src);
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    // Use Type2Str like every other trait in this file, so that element types
+    // without a TypeTraits specialization of their own (e.g. Any) still work.
+    return "std::optional<" + details::Type2Str<T>::v() + ">";
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    return R"({"type":"std::optional","args":[)" + details::TypeSchema<T>::v() + "]}";
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::variant<Args...>.
+ *
+ * To Any: the active alternative is converted with its own TypeTraits.
+ * From Any: alternatives are tried in declaration order and the first one that
+ * accepts the value is used.
+ */
+template <typename... Args>
+struct TypeTraits<std::variant<Args...>> : public TypeTraitsBase {
+ private:
+  using Self = std::variant<Args...>;
+  static constexpr std::size_t Nm = sizeof...(Args);
+
+  /*! \brief Copy from the first alternative (starting at Is) whose strict check passes. */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static Self CopyUnsafeAux(const TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
+      throw;  // unreachable
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
+        return Self{std::in_place_index<Is>,
+                    TypeTraits<ElemType>::CopyFromAnyViewAfterCheck(src)};
+      } else {
+        return CopyUnsafeAux<Is + 1>(src);
+      }
+    }
+  }
+
+  /*!
+   * \brief Move from the first alternative (starting at Is) whose strict check passes.
+   * \note Takes a mutable pointer: MoveFromAnyAfterCheck of each alternative needs
+   *       one, so a const pointer here would not compile once instantiated.
+   */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static Self MoveUnsafeAux(TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
+      throw;  // unreachable
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
+        return Self{std::in_place_index<Is>, TypeTraits<ElemType>::MoveFromAnyAfterCheck(src)};
+      } else {
+        return MoveUnsafeAux<Is + 1>(src);
+      }
+    }
+  }
+
+  /*! \brief Return the first alternative (starting at Is) whose TryCast succeeds. */
+  template <std::size_t Is = 0>
+  TVM_FFI_INLINE static std::optional<Self> TryCastAux(const TVMFFIAny* src) {
+    if constexpr (Is >= Nm) {
+      return std::nullopt;
+    } else {
+      using ElemType = std::variant_alternative_t<Is, Self>;
+      if (auto opt = TypeTraits<ElemType>::TryCastFromAnyView(src)) {
+        return Self{std::in_place_index<Is>, std::move(*opt)};
+      } else {
+        return TryCastAux<Is + 1>(src);
+      }
+    }
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return std::visit(
+        [&](const auto& value) {
+          using ValueType = std::decay_t<decltype(value)>;
+          TypeTraits<ValueType>::CopyToAnyView(value, result);
+        },
+        src);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return std::visit(
+        [&](auto&& value) {
+          using ValueType = std::decay_t<decltype(value)>;
+          TypeTraits<ValueType>::MoveToAny(std::forward<ValueType>(value), result);
+        },
+        std::move(src));
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    return (TypeTraits<Args>::CheckAnyStrict(src) || ...);
+  }
+
+  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    // find the first possible type to copy
+    return CopyUnsafeAux(src);
+  }
+
+  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    // find the first possible type to move
+    return MoveUnsafeAux(src);
+  }
+
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    // try to find the first possible type to copy
+    return TryCastAux(src);
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream os;
+    os << "std::variant<";
+    const char* sep = "";
+    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
+    os << ">";
+    return std::move(os).str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream os;
+    os << R"({"type":"std::variant","args":[)";
+    const char* sep = "";
+    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
+    os << "]}";
+    return std::move(os).str();
+  }
+};
+
+/*! \brief TypeTraits for std::tuple<Args...>, stored as an ffi Array of sizeof...(Args) elements. */
+template <typename... Args>
+struct TypeTraits<std::tuple<Args...>> : public TypeTraits<details::ListTemplate> {
+ private:
+  using Self = std::tuple<Args...>;
+  static constexpr std::size_t Nm = sizeof...(Args);
+  static_assert(Nm > 0, "Zero-length std::tuple is not supported.");
+
+  /*! \brief Accept only an ffi Array whose length equals Nm. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
+    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
+    return n.size_ == Nm;
+  }
+
+  /*! \brief Strictly check one Array element against the tuple element type. */
+  template <typename ElemType>
+  TVM_FFI_INLINE static bool CheckElementStrict(const Any& value) {
+    if constexpr (std::is_same_v<ElemType, Any>) {
+      return true;  // an Any slot accepts every value
+    } else {
+      return TypeTraits<ElemType>::CheckAnyStrict(AnyUnsafe::TVMFFIAnyPtrFromAny(value));
+    }
+  }
+
+  /*!
+   * \brief Strictly check every element of \p n against the matching tuple element type.
+   * \note Previously referenced by CheckAnyStrict but never defined, so CheckAnyStrict
+   *       failed to compile whenever it was instantiated.
+   */
+  template <std::size_t... Is>
+  TVM_FFI_INLINE static bool CheckSubTypeAux(std::index_sequence<Is...>, const ArrayObj& n) {
+    return (CheckElementStrict<std::tuple_element_t<Is, Self>>(n[Is]) && ...);
+  }
+
+  /*! \brief Convert every element; braced init evaluates them left to right. */
+  template <std::size_t... Is>
+  TVM_FFI_INLINE static Self ConstructTupleAux(std::index_sequence<Is...>, const ArrayObj& n) {
+    return Self{ConstructFromAny<std::tuple_element_t<Is, Self>>(n[Is])...};
+  }
+
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(CopyToTuple(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    return MoveToAnyImpl(MoveToTuple(std::move(src)), result);
+  }
+
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
+    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
+    // check static length first
+    if (n.size_ != Nm) return false;
+    // then check element type
+    return CheckSubTypeAux(std::make_index_sequence<Nm>{}, n);
+  }
+
+  /*!
+   * \brief Convert an ffi Array of length Nm into a std::tuple.
+   * \return std::nullopt if the length differs or any element fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (!CheckAnyFast(src)) return std::nullopt;
+    try {
+      auto array = CopyFromAnyImpl<ArrayObj>(src);
+      return ConstructTupleAux(std::make_index_sequence<Nm>{}, *array);
+    } catch (const details::STLTypeMismatch&) {
+      return std::nullopt;
+    }
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream os;
+    os << "std::tuple<";
+    const char* sep = "";
+    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
+    os << ">";
+    return std::move(os).str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream os;
+    os << R"({"type":"std::tuple","args":[)";
+    const char* sep = "";
+    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
+    os << "]}";
+    return std::move(os).str();
+  }
+};
+
+/*! \brief TypeTraits for std::map<K, V>, stored as an ffi Map. */
+template <typename K, typename V>
+struct TypeTraits<std::map<K, V>> : public TypeTraits<details::MapTemplate> {
+ private:
+  using Self = std::map<K, V>;
+
+  /*! \brief Only an ffi Map is a candidate; keys/values are checked during conversion. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    return TypeIndex::kTVMFFIMap == src->type_index;
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    MoveToAnyImpl(CopyToMap(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    MoveToAnyImpl(MoveToMap(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert an ffi Map into a std::map.
+   * \return std::nullopt if \p src is not a Map or any key/value fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (CheckAnyFast(src)) {
+      try {
+        return ConstructMap<Self, /*CanReserve=*/false>(src);
+      } catch (const details::STLTypeMismatch&) {
+        // an element failed to convert; report failure below
+      }
+    }
+    return std::nullopt;
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string out = "std::map<";
+    out += details::Type2Str<K>::v();
+    out += ", ";
+    out += details::Type2Str<V>::v();
+    out += ">";
+    return out;
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string out = R"({"type":"std::map","args":[)";
+    out += details::TypeSchema<K>::v();
+    out += ",";
+    out += details::TypeSchema<V>::v();
+    out += "]}";
+    return out;
+  }
+};
+
+/*! \brief TypeTraits for std::unordered_map<K, V>, stored as an ffi Map. */
+template <typename K, typename V>
+struct TypeTraits<std::unordered_map<K, V>> : public TypeTraits<details::MapTemplate> {
+ private:
+  using Self = std::unordered_map<K, V>;
+
+  /*! \brief Only an ffi Map is a candidate; keys/values are checked during conversion. */
+  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
+    return TypeIndex::kTVMFFIMap == src->type_index;
+  }
+
+ public:
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    MoveToAnyImpl(CopyToMap(src), result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    MoveToAnyImpl(MoveToMap(std::move(src)), result);
+  }
+
+  /*!
+   * \brief Convert an ffi Map into a std::unordered_map (buckets reserved up front).
+   * \return std::nullopt if \p src is not a Map or any key/value fails to convert.
+   */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    if (CheckAnyFast(src)) {
+      try {
+        return ConstructMap<Self, /*CanReserve=*/true>(src);
+      } catch (const details::STLTypeMismatch&) {
+        // an element failed to convert; report failure below
+      }
+    }
+    return std::nullopt;
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string out = "std::unordered_map<";
+    out += details::Type2Str<K>::v();
+    out += ", ";
+    out += details::Type2Str<V>::v();
+    out += ">";
+    return out;
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::string out = R"({"type":"std::unordered_map","args":[)";
+    out += details::TypeSchema<K>::v();
+    out += ",";
+    out += details::TypeSchema<V>::v();
+    out += "]}";
+    return out;
+  }
+};
+
+/*!
+ * \brief TypeTraits for std::function<Ret(Args...)>, routed through TypedFunction.
+ *
+ * To Any, the std::function is wrapped in a TypedFunction. From Any, the ffi
+ * function is cast to a TypedFunction and then wrapped in a std::function.
+ */
+template <typename Ret, typename... Args>
+struct TypeTraits<std::function<Ret(Args...)>> : TypeTraitsBase {
+ private:
+  using Self = std::function<Ret(Args...)>;
+  using Function = TypedFunction<Ret(Args...)>;
+  using ProxyTrait = TypeTraits<Function>;
+
+  /*! \brief Wrap \p f in a std::function that owns it and forwards every call. */
+  TVM_FFI_INLINE static Self GetFunc(Function&& f) {
+    return Self{[callee = std::move(f)](Args... args) -> Ret {
+      return callee(std::forward<Args>(args)...);
+    }};
+  }
+
+ public:
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction;
+  static constexpr bool storage_enabled = false;
+
+  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
+    ProxyTrait::MoveToAny(Function{src}, result);
+  }
+
+  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
+    ProxyTrait::MoveToAny(Function{std::move(src)}, result);
+  }
+
+  /*! \return std::nullopt if \p src cannot be cast to TypedFunction<Ret(Args...)>. */
+  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
+    auto typed = ProxyTrait::TryCastFromAnyView(src);
+    if (!typed.has_value()) {
+      return std::nullopt;
+    }
+    return GetFunc(std::move(*typed));
+  }
+
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::ostringstream ss;
+    ss << "std::function<" << details::Type2Str<Ret>::v() << "(";
+    const char* delim = "";
+    ((ss << delim << details::Type2Str<Args>::v(), delim = ", "), ...);
+    ss << ")>";
+    return std::move(ss).str();
+  }
+
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream ss;
+    ss << R"({"type":"std::function","args":[)" << details::TypeSchema<Ret>::v() << ",[";
+    const char* delim = "";
+    ((ss << delim << details::TypeSchema<Args>::v(), delim = ", "), ...);
+    ss << "]]}";
+    return std::move(ss).str();
+  }
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
+#endif  // TVM_FFI_EXTRA_STL_H_
diff --git a/tests/python/cpp_src/test_stl.cc b/tests/python/cpp_src/test_stl.cc
new file mode 100644
index 0000000..d7ed421
--- /dev/null
+++ b/tests/python/cpp_src/test_stl.cc
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ffi/extra/stl.h>
+#include <tvm/ffi/function.h>
+
+#include <array>
+#include <functional>
+#include <map>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+#include <variant>
+#include <vector>
+
+namespace {
+
+// Returns the two elements of the tuple in swapped order.
+auto test_tuple(std::tuple<int, float> arg) -> std::tuple<float, int> {
+  const auto& [first, second] = arg;
+  return std::make_tuple(second, first);
+}
+
+// Sums every two-element row; a missing input produces a missing output.
+auto test_vector(std::optional<std::vector<std::array<int, 2>>> arg)
+    -> std::optional<std::vector<int>> {
+  if (!arg) {
+    return std::nullopt;
+  }
+  std::vector<int> sums;
+  sums.reserve(arg->size());
+  for (const auto& row : *arg) {
+    sums.push_back(std::accumulate(row.begin(), row.end(), 0));
+  }
+  return sums;
+}
+
+// Names the active alternative ("int" or "float"); for a list, names each item.
+auto test_variant(std::variant<int, float, std::vector<std::variant<int, float>>> arg)
+    -> std::variant<std::string, std::vector<std::string>> {
+  using Scalar = std::variant<int, float>;
+  if (const auto* items = std::get_if<std::vector<Scalar>>(&arg)) {
+    std::vector<std::string> names;
+    for (const Scalar& item : *items) {
+      names.emplace_back(std::holds_alternative<int>(item) ? "int" : "float");
+    }
+    return names;
+  }
+  return std::holds_alternative<int>(arg) ? "int" : "float";
+}
+
+// Inverts an ordered map; for duplicate values the last key in order wins.
+auto test_map(const std::map<std::string, int>& map) -> std::map<int, std::string> {
+  std::map<int, std::string> inverted;
+  for (const auto& entry : map) {
+    inverted[entry.second] = entry.first;
+  }
+  return inverted;
+}
+
+// Inverts an unordered map (same contract as test_map, unordered containers).
+auto test_map_2(const std::unordered_map<std::string, int>& map)
+    -> std::unordered_map<int, std::string> {
+  std::unordered_map<int, std::string> inverted;
+  for (const auto& entry : map) {
+    inverted[entry.second] = entry.first;
+  }
+  return inverted;
+}
+
+// Returns a callable that yields f() + 1.
+auto test_function(std::function<int(void)> f) -> std::function<int(void)> {
+  return [inner = std::move(f)]() -> int { return inner() + 1; };
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_tuple, test_tuple);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_vector, test_vector);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_variant, test_variant);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_map, test_map);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_map_2, test_map_2);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_function, test_function);
+
+}  // namespace
diff --git a/tests/python/test_stl.py b/tests/python/test_stl.py
new file mode 100644
index 0000000..4d87ff2
--- /dev/null
+++ b/tests/python/test_stl.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pathlib
+
+import pytest
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def test_stl() -> None:
+    cpp_path = pathlib.Path(__file__).parent.resolve() / "cpp_src" / 
"test_stl.cc"
+    output_lib_path = tvm_ffi.cpp.build(
+        name="test_stl",
+        cpp_files=[str(cpp_path)],
+    )
+
+    mod: Module = tvm_ffi.load_module(output_lib_path)
+
+    assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
+    assert mod.test_vector(None) == None
+    assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
+    assert mod.test_variant(1) == "int"
+    assert mod.test_variant(1.0) == "float"
+    assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
+    assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+    assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+    assert mod.test_function(lambda: 0)() == 1
+    assert mod.test_function(lambda: 10)() == 11
+
+    with pytest.raises(TypeError):
+        mod.test_tuple([1.5, 2.5])
+    with pytest.raises(TypeError):
+        mod.test_function(lambda: 0)(100)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

Reply via email to