This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d5b942f94e934e0d4cd61de669bafed17cbababe Author: tqchen <[email protected]> AuthorDate: Fri Sep 13 15:20:58 2024 -0400 [FFI] Optional Support --- ffi/include/tvm/ffi/any.h | 17 +- ffi/include/tvm/ffi/container/optional.h | 300 +++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/object.h | 2 +- ffi/tests/example/test_any.cc | 22 +-- ffi/tests/example/test_optional.cc | 107 +++++++++++ 5 files changed, 431 insertions(+), 17 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 031038d4fc..b7181b2297 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -112,12 +112,12 @@ class AnyView { * \return The underlying supporting data of any view * \note This function is used only for testing purposes. */ - TVMFFIAny AsTVMFFIAny() const { return data_; } + TVMFFIAny CopyToTVMFFIAny() const { return data_; } /*! * \return Create an AnyView from TVMFFIAny * \param data the underlying ffi data. */ - static AnyView FromTVMFFIAny(TVMFFIAny data) { + static AnyView CopyFromTVMFFIAny(TVMFFIAny data) { AnyView view; view.data_ = data; return view; @@ -142,7 +142,10 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, } // namespace details /*! - * \brief + * \brief Managed Any that takes strong reference to a value. + * + * \note Developer invariant: the TVMFFIAny data_ + * in the Any can be safely used in AnyView. */ class Any { protected: @@ -198,7 +201,7 @@ class Any { return *this; } /*! \brief Any can be converted to AnyView in zero cost.
*/ - operator AnyView() { return AnyView::FromTVMFFIAny(data_); } + operator AnyView() { return AnyView::CopyFromTVMFFIAny(data_); } // constructor from general types template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> Any(T other) { // NOLINT(*) @@ -227,10 +230,14 @@ class Any { TVM_FFI_UNREACHABLE(); } + bool operator==(std::nullptr_t) const { return data_.type_index == TypeIndex::kTVMFFINone; } + + bool operator!=(std::nullptr_t) const { return data_.type_index != TypeIndex::kTVMFFINone; } + // FFI related operations /*! * Move the current data to FFI any - * \parma result the output to nmove to + * \param result the output to move to */ void MoveToTVMFFIAny(TVMFFIAny* result) { *result = data_; diff --git a/ffi/include/tvm/ffi/container/optional.h b/ffi/include/tvm/ffi/container/optional.h new file mode 100644 index 0000000000..5f08db0a5b --- /dev/null +++ b/ffi/include/tvm/ffi/container/optional.h @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/optional.h + * \brief Runtime Optional container types.
+ */ +#ifndef TVM_FFI_CONTAINER_OPTIONAL_H_ +#define TVM_FFI_CONTAINER_OPTIONAL_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/object.h> + +#include <optional> + +namespace tvm { +namespace ffi { + +/*! + * \brief Optional for non-ObjectRef types, backed by std::optional. + * + * nullptr will be treated as NullOpt + * + * \tparam T The value type held by the Optional (must not derive from ObjectRef). + */ +template <typename T> +class Optional<T, std::enable_if_t<!std::is_base_of_v<ObjectRef, T>>> { + public: + static_assert(!std::is_same_v<std::nullptr_t, T>, "Optional<nullptr> is not well defined"); + // default constructors. + Optional() = default; + Optional(const Optional<T>& other) : data_(other.data_) {} + Optional(Optional<T>&& other) : data_(std::move(other.data_)) {} + Optional<T>& operator=(const Optional<T>& other) { + data_ = other.data_; + return *this; + } + Optional<T>& operator=(Optional<T>&& other) { + data_ = std::move(other.data_); + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : data_(std::move(other)) {} + Optional<T>& operator=(T other) { + data_ = std::move(other); + return *this; + } + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional<T>& operator=(std::nullptr_t) { + data_ = std::nullopt; + return *this; + } + /*! + * \return The contained value. + * \note Throws RuntimeError if the Optional holds no value. + */ + T value() const { + if (!data_.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Bad optional access"; + } + return *data_; + } + /*! + * \return The contained value if the Optional has a value, + * otherwise return the default_value. + */ + T value_or(T default_value) const { return data_.value_or(default_value); } + + /*!
\return Whether the Optional holds a value. */ + explicit operator bool() const { return data_.has_value(); } + + bool has_value() const { return data_.has_value(); } + + bool operator==(const Optional<T>& other) const { return data_ == other.data_; } + + bool operator!=(const Optional<T>& other) const { return data_ != other.data_; } + + template <typename U> + bool operator==(const U& other) const { + return data_ == other; + } + + template <typename U> + bool operator!=(const U& other) const { + return data_ != other; + } + + // operator overloadings with nullptr + bool operator==(std::nullptr_t) const { return !data_.has_value(); } + bool operator!=(std::nullptr_t) const { return data_.has_value(); } + + // helper function to move out value; caller must ensure has_value() + T&& MoveValueNoCheck() { return std::move(*data_); } + // helper function to copy out value; caller must ensure has_value() + T CopyValueNoCheck() const { return *data_; } + + private: + std::optional<T> data_; +}; + +/*! + * \brief Specialization of Optional for ObjectRef. + * + * In such cases, nullptr is treated as NullOpt. + * This specialization reduces the storage cost of + * Optional for ObjectRef. + * + * \tparam T The original ObjectRef. + */ +template <typename T> +class Optional<T, std::enable_if_t<std::is_base_of_v<ObjectRef, T>>> : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of<ObjectRef, T>::value, "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional<T>& other) : ObjectRef(other.data_) {} + Optional(Optional<T>&& other) : ObjectRef(std::move(other.data_)) {} + Optional<T>& operator=(const Optional<T>& other) { + data_ = other.data_; + return *this; + } + Optional<T>& operator=(Optional<T>&& other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType.
+ * \param ptr The object pointer to wrap. + */ + explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + Optional(std::nullopt_t) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional<T>& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + Optional<T>& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional<Integer>(0) is ambiguous + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional<T>& operator=(int val) = delete; + // helper function to move out value; returns by value so no reference to a temporary escapes + T MoveOutValueNoCheck() { return T(std::move(data_)); } + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Bad optional access"; + } + return T(data_); + } + /*! + * \return The internal object pointer with container type of T. + * \note This function does not perform not-null checking. + */ + const ContainerType* get() const { return static_cast<ContainerType*>(data_.get()); } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } + + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { return *this != nullptr; } + /*!
\return Whether the container is not nullptr */ + bool has_value() const { return *this != nullptr; } + + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + auto operator==(const Optional<T>& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional<T>& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { return !(*this == other); } + template <typename U> + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template <typename U> + auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; + } + static constexpr bool _type_is_nullable = true; + + // helper function to move out value; returns by value so no reference to a temporary escapes + T MoveValueNoCheck() { return T(std::move(data_)); } + // helper function to copy out value + T CopyValueNoCheck() const { return T(data_); } +}; + +template <typename T> +inline constexpr bool use_default_type_traits_v<Optional<T>> = false;
+ +template <typename T> +struct TypeTraits<Optional<T>> : public TypeTraitsBase { /* Any/AnyView conversion for Optional<T>: a kTVMFFINone type index maps to an empty Optional, every other case is delegated to TypeTraits<T>. */ + static TVM_FFI_INLINE void CopyToAnyView(const Optional<T>& src, TVMFFIAny* result) { + if (src.has_value()) { + TypeTraits<T>::CopyToAnyView(src.CopyValueNoCheck(), result); + } else { + TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result); /* empty Optional: delegate to the nullptr_t traits */ + } + } + + static TVM_FFI_INLINE void MoveToAny(Optional<T> src, TVMFFIAny* result) { + if (src.has_value()) { + TypeTraits<T>::MoveToAny(src.MoveValueNoCheck(), result); + } else { + TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result); /* empty Optional: delegate to the nullptr_t traits */ + } + } + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraits<T>::GetMismatchTypeInfo(src); + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return true; /* kTVMFFINone is always accepted */ + return TypeTraits<T>::CheckAnyView(src); + } + + static TVM_FFI_INLINE Optional<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return Optional<T>(nullptr); + return TypeTraits<T>::CopyFromAnyViewAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional<Optional<T>> TryCopyFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return Optional<T>(nullptr); + return TypeTraits<T>::TryCopyFromAnyView(src); + } + + static TVM_FFI_INLINE std::string TypeStr() { + return "Optional<" + TypeTraits<T>::TypeStr() + ">"; + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_OPTIONAL_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 120a735c82..07170dce8f 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -352,7 +352,7 @@ class ObjectPtr { }; // Forward declaration, to prevent circular includes. -template <typename T> +template <typename T, typename = void> class Optional; /*!
\brief Base class of all object reference */ diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc index 02d0ad6a23..36d4783dc5 100644 --- a/ffi/tests/example/test_any.cc +++ b/ffi/tests/example/test_any.cc @@ -29,7 +29,7 @@ using namespace tvm::ffi::testing; TEST(Any, Int) { AnyView view0; - EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); std::optional<int64_t> opt_v0 = view0.TryAs<int64_t>(); EXPECT_TRUE(!opt_v0.has_value()); @@ -48,21 +48,21 @@ TEST(Any, Int) { ::tvm::ffi::Error); AnyView view1 = 1; - EXPECT_EQ(view1.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view1.AsTVMFFIAny().v_int64, 1); + EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); int32_t int_v1 = view1; EXPECT_EQ(int_v1, 1); int64_t v1 = 2; view0 = v1; - EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view0.AsTVMFFIAny().v_int64, 2); + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); } TEST(Any, Float) { AnyView view0; - EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); std::optional<double> opt_v0 = view0.TryAs<double>(); EXPECT_TRUE(!opt_v0.has_value()); @@ -85,18 +85,18 @@ TEST(Any, Float) { EXPECT_EQ(float_v1, 1); AnyView view2 = 2.2; - EXPECT_EQ(view2.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view2.AsTVMFFIAny().v_float64, 2.2); + EXPECT_EQ(view2.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + EXPECT_EQ(view2.CopyToTVMFFIAny().v_float64, 2.2); float v1 = 2; view0 = v1; - EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view0.AsTVMFFIAny().v_float64, 2); + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + 
EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2); } TEST(Any, Object) { AnyView view0; - EXPECT_EQ(view0.AsTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); // int object is not nullable std::optional<TInt> opt_v0 = view0.TryAs<TInt>(); diff --git a/ffi/tests/example/test_optional.cc b/ffi/tests/example/test_optional.cc new file mode 100644 index 0000000000..10d9189c23 --- /dev/null +++ b/ffi/tests/example/test_optional.cc @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include <gtest/gtest.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/optional.h> +#include <tvm/ffi/memory.h> + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Optional, TInt) { + Optional<TInt> x; + Optional<TInt> y = TInt(11); + static_assert(sizeof(Optional<TInt>) == sizeof(ObjectRef)); /* Optional<TInt> must be the same size as a plain ObjectRef */ + + EXPECT_TRUE(!x.has_value()); + EXPECT_TRUE(x == nullptr); + EXPECT_EQ(x.value_or(TInt(12))->value, 12); + + EXPECT_TRUE(y.has_value()); + EXPECT_TRUE(y != nullptr); + EXPECT_EQ(y.value_or(TInt(12))->value, 11); +} + +TEST(Optional, double) { + Optional<double> x; + Optional<double> y = 11.0; + static_assert(sizeof(Optional<double>) > sizeof(ObjectRef)); + + EXPECT_TRUE(!x.has_value()); + EXPECT_TRUE(x == nullptr); + EXPECT_EQ(x.value_or(12), 12); + EXPECT_TRUE(x != 12); /* an empty Optional compares unequal to a plain value */ + + EXPECT_TRUE(y.has_value()); + EXPECT_TRUE(y != nullptr); + EXPECT_EQ(y.value_or(12), 11); + EXPECT_TRUE(y == 11); + EXPECT_TRUE(y != 12); +} + +TEST(Optional, AnyConvert_int) { + Optional<int> opt_v0 = 1; + EXPECT_EQ(opt_v0.value(), 1); + EXPECT_TRUE(opt_v0 != nullptr); + + AnyView view0 = opt_v0; + EXPECT_EQ(view0.operator int(), 1); + + Any any1; + Optional<int> opt_v1 = any1; /* any1 is default-constructed; the test expects it to convert to an empty Optional */ + + EXPECT_TRUE(opt_v1 == nullptr); +} + +TEST(Optional, AnyConvert_Array) { + AnyView view0; + Array<Array<TNumber>> arr_nested = {{}, {TInt(1), TFloat(2)}}; + view0 = arr_nested; + + Optional<Array<Array<TNumber>>> opt_arr = view0; + EXPECT_EQ(arr_nested.use_count(), 2); + + Optional<Array<Array<TNumber>>> arr1 = view0; + EXPECT_EQ(arr_nested.use_count(), 3); + EXPECT_EQ(arr1.value()[1][1].as<TFloatObj>()->value, 2); + + Any any1; + Optional<Array<Array<TNumber>>> arr2 = any1; + EXPECT_TRUE(arr2 == nullptr); + + EXPECT_THROW( + { + try { + [[maybe_unused]] Optional<Array<Array<int>>> arr2 = view0; /* converting the TNumber arrays to Array<Array<int>> is expected to raise a TypeError */ + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); +
EXPECT_NE(what.find("to `Optional<Array<Array<int>>>`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +} // namespace
