This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 3076b1f26039d36b05f4ae13c5f4400df0cc0b7f Author: tqchen <[email protected]> AuthorDate: Mon Sep 9 11:38:46 2024 -0400 [FFI] Introduce String support --- ffi/include/tvm/ffi/any.h | 21 ++ ffi/include/tvm/ffi/base_details.h | 72 ++++++ ffi/include/tvm/ffi/endian.h | 89 ++++++++ ffi/include/tvm/ffi/object.h | 24 +- ffi/include/tvm/ffi/string.h | 436 +++++++++++++++++++++++++++++++++++++ ffi/tests/example/test_string.cc | 265 ++++++++++++++++++++++ 6 files changed, 904 insertions(+), 3 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 3fad36eacc..031038d4fc 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -24,6 +24,7 @@ #define TVM_FFI_ANY_H_ #include <tvm/ffi/c_api.h> +#include <tvm/ffi/string.h> #include <tvm/ffi/type_traits.h> #include <string> @@ -271,6 +272,26 @@ struct AnyUnsafe : public ObjectUnsafe { } }; } // namespace details + +// Downcast an object +// NOTE: the implementation is put in here to avoid cyclic dependency +// with the +template <typename SubRef, typename BaseRef, + typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>> +TVM_FFI_INLINE SubRef Downcast(BaseRef ref) { + if (ref.defined()) { + if (!ref->template IsInstance<typename SubRef::ContainerType>()) { + TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } + } else { + if (!SubRef::_type_is_nullable) { + TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } + } + return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref)); +} } // namespace ffi } // namespace tvm #endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index c44ad0608a..c817a6d817 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -26,6 +26,7 @@ #define TVM_FFI_BASE_DETAILS_H_ #include 
<tvm/ffi/c_api.h> +#include <tvm/ffi/endian.h> #include <cstddef> #include <utility> @@ -133,6 +134,77 @@ void for_each(const F& f, Args&&... args) { // NOLINT(*) for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...); } +/*! + * \brief Hash the binary bytes + * \param data The data pointer + * \param size The size of the bytes. + * \return the hash value. + */ +TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) { + const constexpr uint64_t kMultiplier = 1099511628211ULL; + const constexpr uint64_t kMod = 2147483647ULL; + union Union { + uint8_t a[8]; + uint64_t b; + } u; + static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); + const char* it = data; + const char* end = it + size; + uint64_t result = 0; + for (; it + 8 <= end; it += 8) { + if (TVM_FFI_IO_NO_ENDIAN_SWAP) { + u.a[0] = it[0]; + u.a[1] = it[1]; + u.a[2] = it[2]; + u.a[3] = it[3]; + u.a[4] = it[4]; + u.a[5] = it[5]; + u.a[6] = it[6]; + u.a[7] = it[7]; + } else { + u.a[0] = it[7]; + u.a[1] = it[6]; + u.a[2] = it[5]; + u.a[3] = it[4]; + u.a[4] = it[3]; + u.a[5] = it[2]; + u.a[6] = it[1]; + u.a[7] = it[0]; + } + result = (result * kMultiplier + u.b) % kMod; + } + if (it < end) { + u.b = 0; + uint8_t* a = u.a; + if (it + 4 <= end) { + a[0] = it[0]; + a[1] = it[1]; + a[2] = it[2]; + a[3] = it[3]; + it += 4; + a += 4; + } + if (it + 2 <= end) { + a[0] = it[0]; + a[1] = it[1]; + it += 2; + a += 2; + } + if (it + 1 <= end) { + a[0] = it[0]; + it += 1; + a += 1; + } + if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + std::swap(u.a[0], u.a[7]); + std::swap(u.a[1], u.a[6]); + std::swap(u.a[2], u.a[5]); + std::swap(u.a[3], u.a[4]); + } + result = (result * kMultiplier + u.b) % kMod; + } + return result; +} } // namespace details } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/endian.h b/ffi/include/tvm/ffi/endian.h new file mode 100644 index 0000000000..4a73b82e6c --- /dev/null +++ b/ffi/include/tvm/ffi/endian.h @@ -0,0 
+1,89 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/endian.h + * \brief Endian detection and handling + */ +#ifndef TVM_FFI_ENDIAN_H_ +#define TVM_FFI_ENDIAN_H_ + +#include <cstddef> +#include <cstdint> + +#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN +#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 +#endif + +#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN +// If compiled with CMake, use CMake's endian detection logic +#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN +#else +#if defined(__APPLE__) || defined(_WIN32) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) +#include <endian.h> +#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) +#elif defined(__FreeBSD__) || defined(__OpenBSD__) +#include <sys/endian.h> +#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) +#elif defined(__QNX__) +#include <sys/param.h> +#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) +#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__sun) || defined(sun) +#include <sys/isa_defs.h> +#if defined(_LITTLE_ENDIAN) +#define TVM_FFI_LITTLE_ENDIAN 1 +#else 
+#define TVM_FFI_LITTLE_ENDIAN 0 +#endif +#else +#error "Unable to determine endianness of your machine; use CMake to compile" +#endif +#endif + +/*! \brief whether serialize using little endian */ +#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) + +namespace tvm { +namespace ffi { +/*! + * \brief A generic inplace byte swapping function. + * \param data The data pointer. + * \param elem_bytes The number of bytes of the data elements + * \param num_elems Number of elements in the data. + * \note Always try pass in constant elem_bytes to enable + * compiler optimization + */ +inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + uint8_t* bptr = reinterpret_cast<uint8_t*>(data) + elem_bytes * i; + for (size_t j = 0; j < elem_bytes / 2; ++j) { + uint8_t v = bptr[elem_bytes - 1 - j]; + bptr[elem_bytes - 1 - j] = bptr[j]; + bptr[j] = v; + } + } +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ENDIAN_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 9e218899d7..120a735c82 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -26,6 +26,7 @@ #include <tvm/ffi/base_details.h> #include <tvm/ffi/c_api.h> +#include <string> #include <type_traits> #include <utility> @@ -147,8 +148,22 @@ class Object { return details::IsObjectInstance<TargetType>(header_.type_index); } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. 
+ */ + std::string GetTypeKey() const { +#if TVM_FFI_ALLOW_DYN_TYPE + // the function checks that the info exists + const TypeInfo* type_info = details::ObjectGetTypeInfo(header_.type_index); + return type_info->type_key; +#else + return "<unknown>"; +#endif + } + // Information about the object - static constexpr const char* _type_key = "runtime.Object"; + static constexpr const char* _type_key = "object.Object"; // Default object type properties for sub-classes static constexpr bool _type_final = false; @@ -418,8 +433,6 @@ class ObjectRef { // friend classes. friend struct ObjectPtrHash; friend class tvm::ffi::details::ObjectUnsafe; - template <typename SubRef, typename BaseRef> - friend SubRef Downcast(BaseRef ref); }; /*! @@ -621,6 +634,11 @@ struct ObjectUnsafe { return GetHeader(obj_ptr); } + template <typename SubRef, typename BaseRef> + static TVM_FFI_INLINE SubRef DowncastRefNoCheck(BaseRef&& base) { + return SubRef(std::move(base.data_)); + } + // Create objectptr by moving from an existing address of object and setting its // address to nullptr template <typename T> diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h new file mode 100644 index 0000000000..6b43829ff8 --- /dev/null +++ b/ffi/include/tvm/ffi/string.h @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/string.h + * \brief Runtime String type. + */ +#ifndef TVM_FFI_STRING_H_ +#define TVM_FFI_STRING_H_ + +#include <tvm/ffi/base_details.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/memory.h> +#include <tvm/ffi/object.h> + +#include <cstddef> +#include <cstring> +#include <initializer_list> +#include <string> +#include <string_view> +#include <type_traits> +#include <utility> + +// NOTE: We place string in tvm/ffi instead of tvm/ffi/container +// because string itself needs special handling and is an inherent +// core component for return string handling. +// The following dependency relation holds +// containers -> any -> string -> object + +namespace tvm { +namespace ffi { + +/*! \brief An object representing string. It's POD type. */ +class StringObj : public Object { + public: + /*! \brief The pointer to string data. */ + const char* data; + + /*! \brief The length of the string object. 
*/ + uint64_t size; + + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; + static constexpr const char* _type_key = "object.String"; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object); +}; + +namespace details { + +// String object that takes ownership of a std::string by move, +// so the character buffer is never copied +class StringObjStdImpl : public StringObj { + public: + explicit StringObjStdImpl(std::string other) : data_{std::move(other)} { + this->data = data_.data(); + this->size = data_.size(); + } + + private: + std::string data_; +}; + +// inplace string allocation: copies `length` bytes right after the StringObj and appends '\0' +TVM_FFI_INLINE ObjectPtr<StringObj> MakeInplaceString(const char* data, size_t length) { + ObjectPtr<StringObj> p = make_inplace_array_object<StringObj, char>(length + 1); + static_assert(alignof(StringObj) % alignof(char) == 0); + static_assert(sizeof(StringObj) % alignof(char) == 0); + char* dest_data = reinterpret_cast<char*>(p.get()) + sizeof(StringObj); + p->data = dest_data; + p->size = length; + std::memcpy(dest_data, data, length); + dest_data[length] = '\0'; + return p; +} +} // namespace details + +/*! + * \brief Reference to string objects. + * + * \code + * + * // Example to create runtime String reference object from std::string + * std::string s = "hello world"; + * + * // You can create the reference from existing std::string + * String ref{std::move(s)}; + * + * // You can rebind the reference to another string. + * ref = std::string{"hello world2"}; + * + * // You can use the reference as hash map key + * std::unordered_map<String, int32_t> m; + * m[ref] = 1; + * + * // You can compare the reference object with other string objects + * assert(ref == "hello world2"); + * + * // You can convert the reference to std::string again + * std::string s2 = (std::string)ref; + * + * \endcode + */ +class String : public ObjectRef { + public: + /*! + * \brief constructor from char [N] + * + * \param other a char array. 
+ */ + template <size_t N> + String(const char other[N]) // NOLINT(*) + : ObjectRef(details::MakeInplaceString(other, N)) {} + + /*! + * \brief constructor + */ + String() : String("") {} + + /*! + * \brief constructor from raw string + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : ObjectRef(details::MakeInplaceString(other, std::strlen(other))) {} + + /*! + * \brief Construct a new string object + * \param other The std::string object to be copied + */ + String(const std::string& other) // NOLINT(*) + : ObjectRef(details::MakeInplaceString(other.data(), other.size())) {} + + /*! + * \brief Construct a new string object + * \param other The std::string object to be moved + */ + String(std::string&& other) // NOLINT(*) + : ObjectRef(make_object<details::StringObjStdImpl>(std::move(other))) {} + + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(String& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + template <typename T> + String& operator=(T&& other) { + // copy-and-swap idiom + String(std::forward<T>(other)).swap(*this); + return *this; + } + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. 
+ * + * \return zero if both char sequences compare equal. negative if this appears + * before other, positive otherwise. + */ + int compare(const char* other) const { + return memncmp(data(), other, size(), std::strlen(other)); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { + const auto* ptr = get(); + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Return whether the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position; throws std::out_of_range if pos >= size() + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + + private: + /*! + * \brief Compare two char sequences + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appears before other, positive otherwise. 
+ */ + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { + std::string ret(lhs, lhs_size); + ret.append(rhs, rhs_size); + return String(ret); + } + + // Overload + operator + friend String operator+(const String& lhs, const String& rhs); + friend String operator+(const String& lhs, const std::string& rhs); + friend String operator+(const std::string& lhs, const String& rhs); + friend String operator+(const String& lhs, const char* rhs); + friend String operator+(const char* lhs, const String& rhs); + + friend struct AnyEqual; +}; + +inline String operator+(const String& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const std::string& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const std::string& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char* lhs, const String& rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const char* rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// 
Overload < operator +inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char* 
lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +// Overload == operator +inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + // compare bytes as unsigned char so the ordering matches std::string::compare on every platform + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (static_cast<unsigned char>(lhs[i]) < static_cast<unsigned char>(rhs[i])) return -1; + if (static_cast<unsigned char>(lhs[i]) > static_cast<unsigned char>(rhs[i])) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } +} +} // namespace ffi +} // namespace tvm + +namespace std { + +template <> +struct hash<::tvm::ffi::String> { + std::size_t operator()(const ::tvm::ffi::String& str) const { + return ::tvm::ffi::details::StableHashBytes(str.data(), str.size()); + } +}; +} // namespace std +#endif // TVM_FFI_STRING_H_ diff --git a/ffi/tests/example/test_string.cc 
b/ffi/tests/example/test_string.cc new file mode 100644 index 0000000000..23310824a6 --- /dev/null +++ b/ffi/tests/example/test_string.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <gtest/gtest.h> +#include <tvm/ffi/any.h> +#include <tvm/ffi/string.h> + +namespace { + +using namespace tvm::ffi; + +TEST(String, MoveFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s(std::move(source)); + string copy = (string)s; + EXPECT_EQ(copy, expect); + EXPECT_EQ(source.size(), 0); +} + +TEST(String, CopyFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s{source}; + string copy = (string)s; + EXPECT_EQ(copy, expect); + EXPECT_EQ(source.size(), expect.size()); +} + +TEST(String, Assignment) { + using namespace std; + String s{string{"hello"}}; + s = string{"world"}; + EXPECT_EQ(s == "world", true); + string s2{"world2"}; + s = std::move(s2); + EXPECT_EQ(s == "world2", true); +} + +TEST(String, empty) { + using namespace std; + String s{"hello"}; + EXPECT_EQ(s.empty(), false); + s = std::string(""); + EXPECT_EQ(s.empty(), true); +} + +TEST(String, Comparisons) { + using namespace 
std; + string source = "a string"; + string mismatch = "a string but longer"; + String s{"a string"}; + String m{mismatch}; + + EXPECT_EQ("a str" >= s, false); + EXPECT_EQ(s == source, true); + EXPECT_EQ(s == mismatch, false); + EXPECT_EQ(s == source.data(), true); + EXPECT_EQ(s == mismatch.data(), false); + + EXPECT_EQ(s < m, source < mismatch); + EXPECT_EQ(s > m, source > mismatch); + EXPECT_EQ(s <= m, source <= mismatch); + EXPECT_EQ(s >= m, source >= mismatch); + EXPECT_EQ(s == m, source == mismatch); + EXPECT_EQ(s != m, source != mismatch); + + EXPECT_EQ(m < s, mismatch < source); + EXPECT_EQ(m > s, mismatch > source); + EXPECT_EQ(m <= s, mismatch <= source); + EXPECT_EQ(m >= s, mismatch >= source); + EXPECT_EQ(m == s, mismatch == source); + EXPECT_EQ(m != s, mismatch != source); +} + +// Check '\0' handling +TEST(String, null_byte_handling) { + using namespace std; + // Ensure string still compares equal if it contains '\0'. + string v1 = "hello world"; + size_t v1_size = v1.size(); + v1[5] = '\0'; + EXPECT_EQ(v1[5], '\0'); + EXPECT_EQ(v1.size(), v1_size); + String str_v1{v1}; + EXPECT_EQ(str_v1.compare(v1), 0); + EXPECT_EQ(str_v1.size(), v1_size); + + // Ensure bytes after '\0' are taken into account for mismatches. + string v2 = "aaa one"; + string v3 = "aaa two"; + v2[3] = '\0'; + v3[3] = '\0'; + String str_v2{v2}; + String str_v3{v3}; + EXPECT_EQ(str_v2.compare(str_v3), -1); + EXPECT_EQ(str_v2.size(), 7); + // strcmp won't be able to detect the mismatch + EXPECT_EQ(strcmp(v2.data(), v3.data()), 0); + // string::compare can handle \0 since it knows size + EXPECT_LT(v2.compare(v3), 0); + + // If there is mismatch before '\0', should still handle it. 
+ string v4 = "acc one"; + string v5 = "abb two"; + v4[3] = '\0'; + v5[3] = '\0'; + String str_v4{v4}; + String str_v5{v5}; + EXPECT_GT(str_v4.compare(str_v5), 0); + EXPECT_EQ(str_v4.size(), 7); + // strcmp is able to detect the mismatch + EXPECT_GT(strcmp(v4.data(), v5.data()), 0); + // string::compare can handle \0 since it knows size + EXPECT_GT(v4.compare(v5), 0); +} + +TEST(String, compare_same_memory_region_different_size) { + using namespace std; + string source = "a string"; + String str_source{source}; + char* memory = const_cast<char*>(str_source.data()); + EXPECT_EQ(str_source.compare(memory), 0); + // This changes the string size + memory[2] = '\0'; + // memory is logically shorter now + EXPECT_GT(str_source.compare(memory), 0); +} + +TEST(String, compare) { + using namespace std; + constexpr auto mismatch1_cstr = "a string but longer"; + string source = "a string"; + string mismatch1 = mismatch1_cstr; + string mismatch2 = "a strin"; + string mismatch3 = "a b"; + string mismatch4 = "a t"; + String str_source{source}; + String str_mismatch1{mismatch1_cstr}; + String str_mismatch2{mismatch2}; + String str_mismatch3{mismatch3}; + String str_mismatch4{mismatch4}; + + // compare with string + EXPECT_EQ(str_source.compare(source), 0); + EXPECT_TRUE(str_source == source); + EXPECT_TRUE(source == str_source); + EXPECT_TRUE(str_source <= source); + EXPECT_TRUE(source <= str_source); + EXPECT_TRUE(str_source >= source); + EXPECT_TRUE(source >= str_source); + EXPECT_LT(str_source.compare(mismatch1), 0); + EXPECT_TRUE(str_source < mismatch1); + EXPECT_TRUE(mismatch1 != str_source); + EXPECT_GT(str_source.compare(mismatch2), 0); + EXPECT_TRUE(str_source > mismatch2); + EXPECT_TRUE(mismatch2 < str_source); + EXPECT_GT(str_source.compare(mismatch3), 0); + EXPECT_TRUE(str_source > mismatch3); + EXPECT_LT(str_source.compare(mismatch4), 0); + EXPECT_TRUE(str_source < mismatch4); + EXPECT_TRUE(mismatch4 > str_source); + + // compare with char* + 
EXPECT_EQ(str_source.compare(source.data()), 0); + EXPECT_TRUE(str_source == source.data()); + EXPECT_TRUE(source.data() == str_source); + EXPECT_TRUE(str_source <= source.data()); + EXPECT_TRUE(source <= str_source.data()); + EXPECT_TRUE(str_source >= source.data()); + EXPECT_TRUE(source >= str_source.data()); + EXPECT_LT(str_source.compare(mismatch1.data()), 0); + EXPECT_TRUE(str_source < mismatch1.data()); + EXPECT_TRUE(str_source != mismatch1.data()); + EXPECT_TRUE(mismatch1.data() != str_source); + EXPECT_GT(str_source.compare(mismatch2.data()), 0); + EXPECT_TRUE(str_source > mismatch2.data()); + EXPECT_TRUE(mismatch2.data() < str_source); + EXPECT_GT(str_source.compare(mismatch3.data()), 0); + EXPECT_TRUE(str_source > mismatch3.data()); + EXPECT_LT(str_source.compare(mismatch4.data()), 0); + EXPECT_TRUE(str_source < mismatch4.data()); + EXPECT_TRUE(mismatch4.data() > str_source); + + // compare with String + EXPECT_LT(str_source.compare(str_mismatch1), 0); + EXPECT_TRUE(str_source < str_mismatch1); + EXPECT_GT(str_source.compare(str_mismatch2), 0); + EXPECT_TRUE(str_source > str_mismatch2); + EXPECT_GT(str_source.compare(str_mismatch3), 0); + EXPECT_TRUE(str_source > str_mismatch3); + EXPECT_LT(str_source.compare(str_mismatch4), 0); + EXPECT_TRUE(str_source < str_mismatch4); +} + +TEST(String, c_str) { + using namespace std; + string source = "this is a string"; + string mismatch = "mismatch"; + String s{source}; + + EXPECT_EQ(std::strcmp(s.c_str(), source.data()), 0); + EXPECT_NE(std::strcmp(s.c_str(), mismatch.data()), 0); +} + +TEST(String, hash) { + using namespace std; + string source = "this is a string"; + String s{source}; + std::hash<String>()(s); + + std::unordered_map<String, std::string> map; + String k1{string{"k1"}}; + string v1{"v1"}; + String k2{string{"k2"}}; + string v2{"v2"}; + map[k1] = v1; + map[k2] = v2; + + EXPECT_EQ(map[k1], v1); + EXPECT_EQ(map[k2], v2); +} + +TEST(String, Cast) { + using namespace std; + string source = "this is a 
string"; + String s{source}; + ObjectRef r = s; + String s2 = Downcast<String>(r); +} + +TEST(String, Concat) { + String s1("hello"); + String s2("world"); + std::string s3("world"); + String res1 = s1 + s2; + String res2 = s1 + s3; + String res3 = s3 + s1; + String res4 = s1 + "world"; + String res5 = "world" + s1; + + EXPECT_EQ(res1.compare("helloworld"), 0); + EXPECT_EQ(res2.compare("helloworld"), 0); + EXPECT_EQ(res3.compare("worldhello"), 0); + EXPECT_EQ(res4.compare("helloworld"), 0); + EXPECT_EQ(res5.compare("worldhello"), 0); +} +} // namespace \ No newline at end of file
