This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3076b1f26039d36b05f4ae13c5f4400df0cc0b7f
Author: tqchen <[email protected]>
AuthorDate: Mon Sep 9 11:38:46 2024 -0400

    [FFI] Introduce String support
---
 ffi/include/tvm/ffi/any.h          |  21 ++
 ffi/include/tvm/ffi/base_details.h |  72 ++++++
 ffi/include/tvm/ffi/endian.h       |  89 ++++++++
 ffi/include/tvm/ffi/object.h       |  24 +-
 ffi/include/tvm/ffi/string.h       | 436 +++++++++++++++++++++++++++++++++++++
 ffi/tests/example/test_string.cc   | 265 ++++++++++++++++++++++
 6 files changed, 904 insertions(+), 3 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 3fad36eacc..031038d4fc 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -24,6 +24,7 @@
 #define TVM_FFI_ANY_H_
 
 #include <tvm/ffi/c_api.h>
+#include <tvm/ffi/string.h>
 #include <tvm/ffi/type_traits.h>
 
 #include <string>
@@ -271,6 +272,26 @@ struct AnyUnsafe : public ObjectUnsafe {
   }
 };
 }  // namespace details
+
+/*!
+ * \brief Downcast an object reference to a more specific reference type.
+ *
+ * \tparam SubRef The target reference type (for example String).
+ * \tparam BaseRef The source reference type; must derive from ObjectRef.
+ * \param ref The reference to downcast.
+ * \return A SubRef referring to the same object as \p ref.
+ * \note Reports a TypeError through TVM_FFI_THROW when the object is not an
+ *       instance of SubRef::ContainerType, or when \p ref is null and SubRef
+ *       is not nullable.
+ * \note The implementation is put here rather than in object.h to avoid a
+ *       cyclic header dependency: it needs the error-reporting facilities
+ *       (TVM_FFI_THROW), which object.h presumably cannot include.
+ */
+template <typename SubRef, typename BaseRef,
+          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
+TVM_FFI_INLINE SubRef Downcast(BaseRef ref) {
+  if (ref.defined()) {
+    if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
+      TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to "
+                               << SubRef::ContainerType::_type_key << " failed.";
+    }
+  } else {
+    if (!SubRef::_type_is_nullable) {
+      TVM_FFI_THROW(TypeError) << "Downcast from nullptr to not nullable reference of "
+                               << SubRef::ContainerType::_type_key;
+    }
+  }
+  // Type has been verified above; hand the underlying pointer over unchecked.
+  return details::ObjectUnsafe::DowncastRefNoCheck<SubRef>(std::move(ref));
+}
 }  // namespace ffi
 }  // namespace tvm
 #endif  // TVM_FFI_ANY_H_
diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h
index c44ad0608a..c817a6d817 100644
--- a/ffi/include/tvm/ffi/base_details.h
+++ b/ffi/include/tvm/ffi/base_details.h
@@ -26,6 +26,7 @@
 #define TVM_FFI_BASE_DETAILS_H_
 
 #include <tvm/ffi/c_api.h>
+#include <tvm/ffi/endian.h>
 
 #include <cstddef>
 #include <utility>
@@ -133,6 +134,77 @@ void for_each(const F& f, Args&&... args) {  // NOLINT(*)
   for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, 
std::forward<Args>(args)...);
 }
 
+/*!
+ * \brief Hash a byte sequence so that the result does not depend on host byte order.
+ *
+ * The input is consumed in 8-byte words; a trailing partial word is
+ * zero-padded. Each word is assembled in the byte order selected by
+ * TVM_FFI_IO_USE_LITTLE_ENDIAN, swapping when the host differs, so the same
+ * bytes hash to the same value on little- and big-endian hosts.
+ *
+ * \param data The data pointer.
+ * \param size The number of bytes to hash.
+ * \return The hash value, always smaller than kMod (2^31 - 1).
+ * \note `result * kMultiplier` may wrap modulo 2^64 before `% kMod` is
+ *       applied; unsigned wraparound is well defined, so the result stays
+ *       deterministic.
+ */
+TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
+  const constexpr uint64_t kMultiplier = 1099511628211ULL;
+  const constexpr uint64_t kMod = 2147483647ULL;
+  // NOTE(review): writing u.a and then reading u.b is union type punning. It is
+  // supported by GCC/Clang/MSVC but not sanctioned by ISO C++; std::memcpy
+  // would be the portable alternative.
+  union Union {
+    uint8_t a[8];
+    uint64_t b;
+  } u;
+  static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)");
+  const char* it = data;
+  const char* end = it + size;
+  uint64_t result = 0;
+  // Full 8-byte words. Compare the remaining length (end - it) rather than
+  // forming `it + 8`, which may point beyond one-past-the-end (UB).
+  for (; end - it >= 8; it += 8) {
+    if (TVM_FFI_IO_NO_ENDIAN_SWAP) {
+      u.a[0] = it[0];
+      u.a[1] = it[1];
+      u.a[2] = it[2];
+      u.a[3] = it[3];
+      u.a[4] = it[4];
+      u.a[5] = it[5];
+      u.a[6] = it[6];
+      u.a[7] = it[7];
+    } else {
+      u.a[0] = it[7];
+      u.a[1] = it[6];
+      u.a[2] = it[5];
+      u.a[3] = it[4];
+      u.a[4] = it[3];
+      u.a[5] = it[2];
+      u.a[6] = it[1];
+      u.a[7] = it[0];
+    }
+    result = (result * kMultiplier + u.b) % kMod;
+  }
+  // Trailing 1-7 bytes: copy them into the low-address bytes of a zeroed word,
+  // then apply the same byte-order normalization as for full words.
+  if (it < end) {
+    u.b = 0;
+    uint8_t* a = u.a;
+    if (end - it >= 4) {
+      a[0] = it[0];
+      a[1] = it[1];
+      a[2] = it[2];
+      a[3] = it[3];
+      it += 4;
+      a += 4;
+    }
+    if (end - it >= 2) {
+      a[0] = it[0];
+      a[1] = it[1];
+      it += 2;
+      a += 2;
+    }
+    if (end - it >= 1) {
+      a[0] = it[0];
+    }
+    if (!TVM_FFI_IO_NO_ENDIAN_SWAP) {
+      std::swap(u.a[0], u.a[7]);
+      std::swap(u.a[1], u.a[6]);
+      std::swap(u.a[2], u.a[5]);
+      std::swap(u.a[3], u.a[4]);
+    }
+    result = (result * kMultiplier + u.b) % kMod;
+  }
+  return result;
+}
 }  // namespace details
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/include/tvm/ffi/endian.h b/ffi/include/tvm/ffi/endian.h
new file mode 100644
index 0000000000..4a73b82e6c
--- /dev/null
+++ b/ffi/include/tvm/ffi/endian.h
@@ -0,0 +1,89 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/endian.h
+ * \brief Endian detection and handling
+ */
+#ifndef TVM_FFI_ENDIAN_H_
+#define TVM_FFI_ENDIAN_H_
+
+#include <cstddef>
+#include <cstdint>
+
+/*!
+ * \brief Byte order used for I/O and for StableHashBytes: 1 = little endian,
+ *        0 = big endian. Defaults to little endian unless defined beforehand.
+ */
+#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN
+#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1
+#endif
+
+// TVM_FFI_LITTLE_ENDIAN: non-zero when the target machine is little endian.
+#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN
+// If compiled with CMake, use CMake's endian detection logic
+#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN
+#else
+#if defined(__APPLE__) || defined(_WIN32)
+#define TVM_FFI_LITTLE_ENDIAN 1
+#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__)
+#include <endian.h>
+#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN)
+#elif defined(__FreeBSD__) || defined(__OpenBSD__)
+#include <sys/endian.h>
+#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN)
+#elif defined(__QNX__)
+#include <sys/param.h>
+#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN)
+#elif defined(__EMSCRIPTEN__) || defined(__hexagon__)
+#define TVM_FFI_LITTLE_ENDIAN 1
+#elif defined(__sun) || defined(sun)
+#include <sys/isa_defs.h>
+#if defined(_LITTLE_ENDIAN)
+#define TVM_FFI_LITTLE_ENDIAN 1
+#else
+#define TVM_FFI_LITTLE_ENDIAN 0
+#endif
+#else
+#error "Unable to determine endianness of your machine; use CMake to compile"
+#endif
+#endif
+
+/*!
+ * \brief Non-zero when the host byte order equals TVM_FFI_IO_USE_LITTLE_ENDIAN,
+ *        i.e. I/O data can be used as-is without byte swapping.
+ */
+#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN)
+
+namespace tvm {
+namespace ffi {
+/*!
+ * \brief Reverse the byte order of each element of an array, in place.
+ * \param data Start of the array (num_elems * elem_bytes bytes, writable).
+ * \param elem_bytes Size of one element in bytes; 0 or 1 leaves data unchanged.
+ * \param num_elems Number of consecutive elements to process.
+ * \note Prefer passing a compile-time constant elem_bytes so the compiler can
+ *       specialize the inner loop.
+ */
+inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) {
+  uint8_t* elem = static_cast<uint8_t*>(data);
+  for (size_t remaining = num_elems; remaining != 0; --remaining, elem += elem_bytes) {
+    // Walk inward from both ends of the element, exchanging byte pairs.
+    uint8_t* lo = elem;
+    uint8_t* hi = elem + elem_bytes;
+    while (hi - lo > 1) {
+      --hi;
+      const uint8_t tmp = *lo;
+      *lo = *hi;
+      *hi = tmp;
+      ++lo;
+    }
+  }
+}
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_ENDIAN_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 9e218899d7..120a735c82 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -26,6 +26,7 @@
 #include <tvm/ffi/base_details.h>
 #include <tvm/ffi/c_api.h>
 
+#include <string>
 #include <type_traits>
 #include <utility>
 
@@ -147,8 +148,22 @@ class Object {
     return details::IsObjectInstance<TargetType>(header_.type_index);
   }
 
+  /*!
+   * \brief Get the type key registered for this object's runtime type.
+   * \return The type key, or "<unknown>" when TVM_FFI_ALLOW_DYN_TYPE is off.
+   * \note Involves a type-info lookup and a std::string construction, so it
+   *       is relatively expensive; intended for error reporting.
+   */
+  std::string GetTypeKey() const {
+#if TVM_FFI_ALLOW_DYN_TYPE
+    // ObjectGetTypeInfo checks that the type info exists.
+    return details::ObjectGetTypeInfo(header_.type_index)->type_key;
+#else
+    return "<unknown>";
+#endif
+  }
+
   // Information about the object
-  static constexpr const char* _type_key = "runtime.Object";
+  static constexpr const char* _type_key = "object.Object";
 
   // Default object type properties for sub-classes
   static constexpr bool _type_final = false;
@@ -418,8 +433,6 @@ class ObjectRef {
   // friend classes.
   friend struct ObjectPtrHash;
   friend class tvm::ffi::details::ObjectUnsafe;
-  template <typename SubRef, typename BaseRef>
-  friend SubRef Downcast(BaseRef ref);
 };
 
 /*!
@@ -621,6 +634,11 @@ struct ObjectUnsafe {
     return GetHeader(obj_ptr);
   }
 
+  /*!
+   * \brief Re-wrap a reference's object pointer as SubRef without any type check.
+   * \tparam SubRef The target reference type.
+   * \param base The source reference. Its data_ member is passed to SubRef's
+   *        constructor via std::move, even if \p base was given as an lvalue,
+   *        so treat \p base as moved-from afterwards.
+   * \return A SubRef constructed from base's object pointer.
+   * \note The caller is responsible for having verified the type (Downcast in
+   *       any.h does this before calling).
+   */
+  template <typename SubRef, typename BaseRef>
+  static TVM_FFI_INLINE SubRef DowncastRefNoCheck(BaseRef&& base) {
+    return SubRef(std::move(base.data_));
+  }
+
   // Create objectptr by moving from an existing address of object and setting 
its
   // address to nullptr
   template <typename T>
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
new file mode 100644
index 0000000000..6b43829ff8
--- /dev/null
+++ b/ffi/include/tvm/ffi/string.h
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/string.h
+ * \brief Runtime String type.
+ */
+#ifndef TVM_FFI_STRING_H_
+#define TVM_FFI_STRING_H_
+
+#include <tvm/ffi/base_details.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/memory.h>
+#include <tvm/ffi/object.h>
+
+#include <cstddef>
+#include <cstring>
+#include <initializer_list>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <utility>
+
+// NOTE: We place string in tvm/ffi instead of tvm/ffi/container
+// because string itself needs special handling and is an inherent
+// core component for return string handling.
+// The following dependency relation holds
+// containers -> any -> string -> object
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Object node describing a character sequence by pointer and length.
+ *
+ * StringObj has no buffer member of its own. `data` points either into storage
+ * placed right after the object (details::MakeInplaceString) or into a
+ * std::string held by a subclass (details::StringObjStdImpl).
+ */
+class StringObj : public Object {
+ public:
+  /*!
+   * \brief The pointer to string data.
+   * \note Both constructions in this header leave a '\0' at data[size], so the
+   *       buffer is also usable as a C string (up to any embedded '\0').
+   */
+  const char* data;
+
+  /*! \brief The length of the string in bytes, not counting the trailing '\0'. */
+  uint64_t size;
+
+  // Static type information for the FFI type system: fixed type index,
+  // registered type key, and the final marker.
+  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr;
+  static constexpr const char* _type_key = "object.String";
+  static const constexpr bool _type_final = true;
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
+};
+
+namespace details {
+
+// StringObj variant that owns a std::string, so a String can be built from
+// an rvalue std::string without copying its characters.
+// NOTE(review): `data` points into data_, so an instance must not be copied or
+// moved after construction; String only creates it via make_object.
+class StringObjStdImpl : public StringObj {
+ public:
+  /*!
+   * \brief Take over the buffer of \p other.
+   * \param other The source string. It is taken by value and then moved into
+   *        data_, so an rvalue argument avoids copying the heap buffer.
+   */
+  explicit StringObjStdImpl(std::string other) : data_{std::move(other)} {
+    // data_ lives as long as this object, so pointing into it is safe.
+    this->data = data_.data();
+    this->size = data_.size();
+  }
+
+ private:
+  /*! \brief Backing storage that `data` points into. */
+  std::string data_;
+};
+
+// Allocate a StringObj whose characters live in the same inplace-array
+// allocation, immediately after the StringObj header, followed by '\0'.
+// \param data Pointer to `length` bytes to copy (may contain '\0' bytes).
+// \param length Number of bytes to copy.
+// \return The newly allocated string object.
+TVM_FFI_INLINE ObjectPtr<StringObj> MakeInplaceString(const char* data, size_t length) {
+  // Reserve length + 1 chars so there is room for the trailing '\0'.
+  ObjectPtr<StringObj> p = make_inplace_array_object<StringObj, char>(length + 1);
+  static_assert(alignof(StringObj) % alignof(char) == 0);
+  static_assert(sizeof(StringObj) % alignof(char) == 0);
+  char* dest_data = reinterpret_cast<char*>(p.get()) + sizeof(StringObj);
+  p->data = dest_data;
+  p->size = length;
+  // memcpy with a null source pointer is undefined even for zero bytes, so
+  // skip the call when there is nothing to copy.
+  if (length != 0) {
+    std::memcpy(dest_data, data, length);
+  }
+  dest_data[length] = '\0';
+  return p;
+}
+}  // namespace details
+
+/*!
+ * \brief Reference to string objects.
+ *
+ * \code
+ *
+ * // Example to create runtime String reference object from std::string
+ * std::string s = "hello world";
+ *
+ * // You can create the reference from existing std::string
+ * String ref{std::move(s)};
+ *
+ * // You can rebind the reference to another string.
+ * ref = std::string{"hello world2"};
+ *
+ * // You can use the reference as hash map key
+ * std::unordered_map<String, int32_t> m;
+ * m[ref] = 1;
+ *
+ * // You can compare the reference object with other string objects
+ * assert(ref == "hello world2");
+ *
+ * // You can convert the reference to std::string again
+ * std::string s2 = static_cast<std::string>(ref);
+ *
+ * \endcode
+ */
+class String : public ObjectRef {
+ public:
+  // NOTE: there is intentionally no `const char (&)[N]` overload. A literal or
+  // char array decays to `const char*` and is measured with strlen, so the
+  // trailing '\0' of a literal is never counted as part of the string.
+
+  /*!
+   * \brief Construct an empty string.
+   */
+  String() : String("") {}
+
+  /*!
+   * \brief Construct from a null-terminated C string; the characters are copied.
+   *
+   * \param other a null-terminated char array. Must not be nullptr.
+   */
+  String(const char* other)  // NOLINT(*)
+      : ObjectRef(details::MakeInplaceString(other, std::strlen(other))) {}
+
+  /*!
+   * \brief Construct a new string object
+   * \param other The std::string object to be copied
+   */
+  String(const std::string& other)  // NOLINT(*)
+      : ObjectRef(details::MakeInplaceString(other.data(), other.size())) {}
+
+  /*!
+   * \brief Construct a new string object
+   * \param other The std::string object to be moved into the new string object
+   */
+  String(std::string&& other)  // NOLINT(*)
+      : ObjectRef(make_object<details::StringObjStdImpl>(std::move(other))) {}
+
+  /*!
+   * \brief Swap this String with another string
+   * \param other The other string
+   */
+  void swap(String& other) {  // NOLINT(*)
+    std::swap(data_, other.data_);
+  }
+
+  /*!
+   * \brief Assign from any value a String can be constructed from
+   *        (String, std::string, const char*, ...).
+   * \param other The value to assign.
+   * \return *this
+   */
+  template <typename T>
+  String& operator=(T&& other) {
+    // copy-and-swap idiom: build the new value first, then take it over.
+    String(std::forward<T>(other)).swap(*this);
+    return *this;
+  }
+
+  /*!
+   * \brief Compares this String object to other
+   *
+   * \param other The String to compare with.
+   *
+   * \return zero if both char sequences compare equal. negative if this appear
+   * before other, positive otherwise.
+   */
+  int compare(const String& other) const {
+    return memncmp(data(), other.data(), size(), other.size());
+  }
+
+  /*!
+   * \brief Compares this String object to other
+   *
+   * \param other The string to compare with.
+   *
+   * \return zero if both char sequences compare equal. negative if this appear
+   * before other, positive otherwise.
+   */
+  int compare(const std::string& other) const {
+    return memncmp(data(), other.data(), size(), other.size());
+  }
+
+  /*!
+   * \brief Compares this to other
+   *
+   * \param other The null-terminated character array to compare with.
+   *
+   * \return zero if both char sequences compare equal. negative if this appear
+   * before other, positive otherwise.
+   */
+  int compare(const char* other) const {
+    return memncmp(data(), other, size(), std::strlen(other));
+  }
+
+  /*!
+   * \brief Returns a pointer to the null-terminated char array in the string.
+   *
+   * \return const char*
+   */
+  const char* c_str() const { return get()->data; }
+
+  /*!
+   * \brief Return the length of the string
+   *
+   * \return size_t string length
+   */
+  size_t size() const {
+    const auto* ptr = get();
+    return ptr->size;
+  }
+
+  /*!
+   * \brief Return the length of the string
+   *
+   * \return size_t string length
+   */
+  size_t length() const { return size(); }
+
+  /*!
+   * \brief Return if the string is empty
+   *
+   * \return true if empty, false otherwise.
+   */
+  bool empty() const { return size() == 0; }
+
+  /*!
+   * \brief Read an element.
+   * \param pos The position at which to read the character.
+   *
+   * \return The char at position
+   * \throws std::out_of_range if pos >= size().
+   */
+  char at(size_t pos) const {
+    if (pos < size()) {
+      return data()[pos];
+    } else {
+      throw std::out_of_range("tvm::String index out of bounds");
+    }
+  }
+
+  /*!
+   * \brief Return the data pointer
+   *
+   * \return const char* data pointer
+   */
+  const char* data() const { return get()->data; }
+
+  /*!
+   * \brief Convert String to an std::string object (embedded '\0' bytes kept)
+   *
+   * \return std::string
+   */
+  operator std::string() const { return std::string{get()->data, size()}; }
+
+  TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
+
+ private:
+  /*!
+   * \brief Compare two char sequence
+   *
+   * \param lhs Pointers to the char array to compare
+   * \param rhs Pointers to the char array to compare
+   * \param lhs_count Length of the char array to compare
+   * \param rhs_count Length of the char array to compare
+   * \return int zero if both char sequences compare equal. negative if this
+   * appear before other, positive otherwise.
+   */
+  static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
+
+  /*!
+   * \brief Concatenate two char sequences
+   *
+   * \param lhs Pointers to the lhs char array
+   * \param lhs_size The size of the lhs char array
+   * \param rhs Pointers to the rhs char array
+   * \param rhs_size The size of the rhs char array
+   *
+   * \return The concatenated char sequence
+   */
+  static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
+    std::string ret(lhs, lhs_size);
+    ret.append(rhs, rhs_size);
+    return String(ret);
+  }
+
+  // Overload + operator
+  friend String operator+(const String& lhs, const String& rhs);
+  friend String operator+(const String& lhs, const std::string& rhs);
+  friend String operator+(const std::string& lhs, const String& rhs);
+  friend String operator+(const String& lhs, const char* rhs);
+  friend String operator+(const char* lhs, const String& rhs);
+
+  friend struct AnyEqual;
+};
+
+// Concatenation overloads. Each one forwards the raw bytes and byte counts of
+// both operands to String::Concat; C-string operands are measured with strlen.
+inline String operator+(const String& lhs, const String& rhs) {
+  return String::Concat(lhs.data(), lhs.size(), rhs.data(), rhs.size());
+}
+
+inline String operator+(const String& lhs, const std::string& rhs) {
+  return String::Concat(lhs.data(), lhs.size(), rhs.data(), rhs.size());
+}
+
+inline String operator+(const std::string& lhs, const String& rhs) {
+  return String::Concat(lhs.data(), lhs.size(), rhs.data(), rhs.size());
+}
+
+inline String operator+(const char* lhs, const String& rhs) {
+  return String::Concat(lhs, std::strlen(lhs), rhs.data(), rhs.size());
+}
+
+inline String operator+(const String& lhs, const char* rhs) {
+  return String::Concat(lhs.data(), lhs.size(), rhs, std::strlen(rhs));
+}
+
+// Relational and equality operators between String, std::string and
+// null-terminated C strings. All are expressed through String::compare; when
+// the String is the right-hand operand the test is mirrored, e.g. `a < s` is
+// evaluated as `s.compare(a) > 0`.
+
+// Overload < operator
+inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }
+inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
+inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }
+inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }
+inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
+
+// Overload > operator
+inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }
+inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
+inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }
+inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }
+inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
+
+// Overload <= operator
+inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }
+inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
+inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }
+inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }
+inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
+
+// Overload >= operator
+inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }
+inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
+inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }
+inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }
+inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
+
+// Overload == operator
+inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
+inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
+inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
+inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }
+inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
+
+// Overload != operator
+inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }
+inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
+inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }
+inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }
+inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
+
+/*!
+ * \brief Stream the raw bytes of a String, including any embedded '\0'.
+ * \param out The destination stream.
+ * \param input The string to write.
+ * \return \p out, for chaining.
+ */
+inline std::ostream& operator<<(std::ostream& out, const String& input) {
+  return out.write(input.data(), input.size());
+}
+
+/*!
+ * \brief Lexicographic comparison of two counted byte sequences.
+ *
+ * Bytes are compared as unsigned char, the ordering used by std::memcmp and
+ * std::char_traits<char> (hence std::string::compare). String ordering
+ * therefore agrees with std::string whether or not plain char is signed.
+ *
+ * \return -1, 0 or 1; a proper prefix orders before the longer sequence.
+ */
+inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
+                           size_t rhs_count) {
+  // Fast path: the same view of the same memory.
+  if (lhs == rhs && lhs_count == rhs_count) return 0;
+
+  const size_t common = lhs_count < rhs_count ? lhs_count : rhs_count;
+  for (size_t i = 0; i < common; ++i) {
+    const unsigned char l = static_cast<unsigned char>(lhs[i]);
+    const unsigned char r = static_cast<unsigned char>(rhs[i]);
+    if (l < r) return -1;
+    if (l > r) return 1;
+  }
+  // The common prefix is equal: the shorter sequence orders first.
+  if (lhs_count < rhs_count) return -1;
+  if (lhs_count > rhs_count) return 1;
+  return 0;
+}
+}  // namespace ffi
+}  // namespace tvm
+
+namespace std {
+
+/*!
+ * \brief std::hash specialization so String can key unordered containers.
+ *        Hashes all bytes of the string (embedded '\0' included) with
+ *        StableHashBytes, so equal contents always hash equally.
+ */
+template <>
+struct hash<::tvm::ffi::String> {
+  std::size_t operator()(const ::tvm::ffi::String& str) const {
+    const char* bytes = str.data();
+    const std::size_t num_bytes = str.size();
+    return static_cast<std::size_t>(::tvm::ffi::details::StableHashBytes(bytes, num_bytes));
+  }
+};
+}  // namespace std
+#endif  // TVM_FFI_STRING_H_
diff --git a/ffi/tests/example/test_string.cc b/ffi/tests/example/test_string.cc
new file mode 100644
index 0000000000..23310824a6
--- /dev/null
+++ b/ffi/tests/example/test_string.cc
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/string.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+TEST(String, MoveFromStd) {
+  std::string source = "this is a string";
+  const std::string expect = source;
+  String s(std::move(source));
+  EXPECT_EQ(static_cast<std::string>(s), expect);
+  // Relies on the moved-from std::string being left empty, which holds for
+  // the common standard libraries though the standard leaves it unspecified.
+  EXPECT_TRUE(source.empty());
+}
+
+TEST(String, CopyFromStd) {
+  std::string source = "this is a string";
+  const std::string expect = source;
+  String s{source};
+  EXPECT_EQ(static_cast<std::string>(s), expect);
+  // Copy construction must not disturb the source string.
+  EXPECT_EQ(source.size(), expect.size());
+}
+
+TEST(String, Assignment) {
+  String s{std::string{"hello"}};
+  s = std::string{"world"};
+  EXPECT_TRUE(s == "world");
+  std::string next{"world2"};
+  s = std::move(next);
+  EXPECT_TRUE(s == "world2");
+}
+
+TEST(String, empty) {
+  String s{"hello"};
+  EXPECT_FALSE(s.empty());
+  s = std::string("");
+  EXPECT_TRUE(s.empty());
+}
+
+TEST(String, Comparisons) {
+  const std::string source = "a string";
+  const std::string mismatch = "a string but longer";
+  const String s{"a string"};
+  const String m{mismatch};
+
+  // Mixed-type comparisons against an equal and a longer sequence.
+  EXPECT_FALSE("a str" >= s);
+  EXPECT_TRUE(s == source);
+  EXPECT_FALSE(s == mismatch);
+  EXPECT_TRUE(s == source.data());
+  EXPECT_FALSE(s == mismatch.data());
+
+  // String-vs-String operators must agree with the std::string operators,
+  // in both operand orders.
+  EXPECT_EQ(s < m, source < mismatch);
+  EXPECT_EQ(s > m, source > mismatch);
+  EXPECT_EQ(s <= m, source <= mismatch);
+  EXPECT_EQ(s >= m, source >= mismatch);
+  EXPECT_EQ(s == m, source == mismatch);
+  EXPECT_EQ(s != m, source != mismatch);
+
+  EXPECT_EQ(m < s, mismatch < source);
+  EXPECT_EQ(m > s, mismatch > source);
+  EXPECT_EQ(m <= s, mismatch <= source);
+  EXPECT_EQ(m >= s, mismatch >= source);
+  EXPECT_EQ(m == s, mismatch == source);
+  EXPECT_EQ(m != s, mismatch != source);
+}
+
+// Check '\0' handling
+TEST(String, null_byte_handling) {
+  // A String holding an embedded '\0' keeps its full length and still
+  // compares equal to the std::string it was built from.
+  std::string with_nul = "hello world";
+  const size_t full_size = with_nul.size();
+  with_nul[5] = '\0';
+  EXPECT_EQ(with_nul[5], '\0');
+  EXPECT_EQ(with_nul.size(), full_size);
+  const String str_with_nul{with_nul};
+  EXPECT_EQ(str_with_nul.compare(with_nul), 0);
+  EXPECT_EQ(str_with_nul.size(), full_size);
+
+  // Bytes after the '\0' must still take part in the comparison.
+  std::string one = "aaa one";
+  std::string two = "aaa two";
+  one[3] = '\0';
+  two[3] = '\0';
+  const String str_one{one};
+  const String str_two{two};
+  EXPECT_EQ(str_one.compare(str_two), -1);
+  EXPECT_EQ(str_one.size(), 7);
+  // strcmp stops at the '\0' and cannot see the difference...
+  EXPECT_EQ(std::strcmp(one.data(), two.data()), 0);
+  // ...while std::string::compare uses the stored size.
+  EXPECT_LT(one.compare(two), 0);
+
+  // A mismatch located before the '\0' must be detected as well.
+  std::string acc = "acc one";
+  std::string abb = "abb two";
+  acc[3] = '\0';
+  abb[3] = '\0';
+  const String str_acc{acc};
+  const String str_abb{abb};
+  EXPECT_GT(str_acc.compare(str_abb), 0);
+  EXPECT_EQ(str_acc.size(), 7);
+  // strcmp does see this mismatch.
+  EXPECT_GT(std::strcmp(acc.data(), abb.data()), 0);
+  EXPECT_GT(acc.compare(abb), 0);
+}
+
+TEST(String, compare_same_memory_region_different_size) {
+  const std::string source = "a string";
+  String str_source{source};
+  // Alias the String's own buffer as a C string; the test mutates it below.
+  char* raw = const_cast<char*>(str_source.data());
+  EXPECT_EQ(str_source.compare(raw), 0);
+  // Truncate the C-string view: strlen(raw) drops to 2 while the String
+  // still reports its original size, so the String now orders after it.
+  raw[2] = '\0';
+  EXPECT_GT(str_source.compare(raw), 0);
+}
+
+TEST(String, compare) {
+  using namespace std;
+  constexpr auto mismatch1_cstr = "a string but longer";
+  string source = "a string";
+  string mismatch1 = mismatch1_cstr;
+  string mismatch2 = "a strin";
+  string mismatch3 = "a b";
+  string mismatch4 = "a t";
+  String str_source{source};
+  String str_mismatch1{mismatch1_cstr};
+  String str_mismatch2{mismatch2};
+  String str_mismatch3{mismatch3};
+  String str_mismatch4{mismatch4};
+
+  // compare with string
+  EXPECT_EQ(str_source.compare(source), 0);
+  EXPECT_TRUE(str_source == source);
+  EXPECT_TRUE(source == str_source);
+  EXPECT_TRUE(str_source <= source);
+  EXPECT_TRUE(source <= str_source);
+  EXPECT_TRUE(str_source >= source);
+  EXPECT_TRUE(source >= str_source);
+  EXPECT_LT(str_source.compare(mismatch1), 0);
+  EXPECT_TRUE(str_source < mismatch1);
+  EXPECT_TRUE(mismatch1 != str_source);
+  EXPECT_GT(str_source.compare(mismatch2), 0);
+  EXPECT_TRUE(str_source > mismatch2);
+  EXPECT_TRUE(mismatch2 < str_source);
+  EXPECT_GT(str_source.compare(mismatch3), 0);
+  EXPECT_TRUE(str_source > mismatch3);
+  EXPECT_LT(str_source.compare(mismatch4), 0);
+  EXPECT_TRUE(str_source < mismatch4);
+  EXPECT_TRUE(mismatch4 > str_source);
+
+  // compare with char*
+  EXPECT_EQ(str_source.compare(source.data()), 0);
+  EXPECT_TRUE(str_source == source.data());
+  EXPECT_TRUE(source.data() == str_source);
+  EXPECT_TRUE(str_source <= source.data());
+  // The char* operand goes on the left so the (const char*, String) overloads
+  // are exercised rather than std::string vs const char*.
+  EXPECT_TRUE(source.data() <= str_source);
+  EXPECT_TRUE(str_source >= source.data());
+  EXPECT_TRUE(source.data() >= str_source);
+  EXPECT_LT(str_source.compare(mismatch1.data()), 0);
+  EXPECT_TRUE(str_source < mismatch1.data());
+  EXPECT_TRUE(str_source != mismatch1.data());
+  EXPECT_TRUE(mismatch1.data() != str_source);
+  EXPECT_GT(str_source.compare(mismatch2.data()), 0);
+  EXPECT_TRUE(str_source > mismatch2.data());
+  EXPECT_TRUE(mismatch2.data() < str_source);
+  EXPECT_GT(str_source.compare(mismatch3.data()), 0);
+  EXPECT_TRUE(str_source > mismatch3.data());
+  EXPECT_LT(str_source.compare(mismatch4.data()), 0);
+  EXPECT_TRUE(str_source < mismatch4.data());
+  EXPECT_TRUE(mismatch4.data() > str_source);
+
+  // compare with String
+  EXPECT_LT(str_source.compare(str_mismatch1), 0);
+  EXPECT_TRUE(str_source < str_mismatch1);
+  EXPECT_GT(str_source.compare(str_mismatch2), 0);
+  EXPECT_TRUE(str_source > str_mismatch2);
+  EXPECT_GT(str_source.compare(str_mismatch3), 0);
+  EXPECT_TRUE(str_source > str_mismatch3);
+  EXPECT_LT(str_source.compare(str_mismatch4), 0);
+  EXPECT_TRUE(str_source < str_mismatch4);
+}
+
+TEST(String, c_str) {
+  const std::string source = "this is a string";
+  const std::string other = "mismatch";
+  const String s{source};
+  // c_str() must expose a null-terminated copy of the contents.
+  EXPECT_STREQ(s.c_str(), source.data());
+  EXPECT_STRNE(s.c_str(), other.data());
+}
+
+TEST(String, hash) {
+  using namespace std;
+  string source = "this is a string";
+  String s{source};
+  // Equal contents must hash equally, whichever storage backs the String
+  // (in-place copy vs. moved-in std::string).
+  EXPECT_EQ(std::hash<String>()(s), std::hash<String>()(String(string(source))));
+
+  std::unordered_map<String, std::string> map;
+  String k1{string{"k1"}};
+  string v1{"v1"};
+  String k2{string{"k2"}};
+  string v2{"v2"};
+  map[k1] = v1;
+  map[k2] = v2;
+
+  EXPECT_EQ(map[k1], v1);
+  EXPECT_EQ(map[k2], v2);
+  // An independently constructed key must find the same entry.
+  EXPECT_TRUE(map.find(String("k1")) != map.end());
+}
+
+TEST(String, Cast) {
+  using namespace std;
+  string source = "this is a string";
+  String s{source};
+  ObjectRef r = s;
+  String s2 = Downcast<String>(r);
+  // The downcast must preserve the contents and refer to the same buffer.
+  EXPECT_EQ(s2, source);
+  EXPECT_EQ(s2.data(), s.data());
+}
+
+TEST(String, Concat) {
+  const String hello("hello");
+  const String world("world");
+  const std::string std_world("world");
+
+  // Every operand-type combination must produce the same concatenation.
+  EXPECT_EQ((hello + world).compare("helloworld"), 0);
+  EXPECT_EQ((hello + std_world).compare("helloworld"), 0);
+  EXPECT_EQ((std_world + hello).compare("worldhello"), 0);
+  EXPECT_EQ((hello + "world").compare("helloworld"), 0);
+  EXPECT_EQ(("world" + hello).compare("worldhello"), 0);
+}
+}  // namespace
\ No newline at end of file

Reply via email to