This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 170302bab2 [FFI][DOCS] Initial bringup of cpp docs (#18279)
170302bab2 is described below

commit 170302bab2046faec8d2effe4a25e7af3c2446be
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Sep 7 21:48:44 2025 -0400

    [FFI][DOCS] Initial bringup of cpp docs (#18279)
    
    This PR brings up an initial version of the C++ API docs.
---
 ffi/docs/.gitignore                           |   2 +-
 ffi/docs/Makefile                             |  13 ++--
 ffi/docs/README.md                            |  11 +++
 ffi/docs/conf.py                              |  40 ++++++++++
 ffi/docs/guides/cpp_guide.md                  |   2 +
 ffi/docs/index.rst                            |   6 ++
 ffi/docs/reference/cpp/index.rst              | 107 ++++++++++++++++++++++++++
 ffi/docs/requirements.txt                     |   2 +
 ffi/include/tvm/ffi/any.h                     |  78 +++++++++++++++----
 ffi/include/tvm/ffi/base_details.h            |   4 +-
 ffi/include/tvm/ffi/c_api.h                   |  38 ++++++---
 ffi/include/tvm/ffi/container/array.h         |  78 ++++++++++++++++---
 ffi/include/tvm/ffi/container/map.h           |  58 +++++++++++++-
 ffi/include/tvm/ffi/container/shape.h         |   7 +-
 ffi/include/tvm/ffi/container/tensor.h        |  35 ++++++---
 ffi/include/tvm/ffi/container/tuple.h         |  51 +++++++++++-
 ffi/include/tvm/ffi/container/variant.h       |  50 +++++++++++-
 ffi/include/tvm/ffi/dtype.h                   |  12 ++-
 ffi/include/tvm/ffi/error.h                   |  39 +++++++++-
 ffi/include/tvm/ffi/extra/base64.h            |   2 +-
 ffi/include/tvm/ffi/extra/c_env_api.h         |   5 +-
 ffi/include/tvm/ffi/extra/json.h              |   2 +-
 ffi/include/tvm/ffi/extra/module.h            |  10 ++-
 ffi/include/tvm/ffi/extra/serialization.h     |   2 +-
 ffi/include/tvm/ffi/extra/structural_equal.h  |   2 +-
 ffi/include/tvm/ffi/extra/structural_hash.h   |   2 +-
 ffi/include/tvm/ffi/function.h                |  82 +++++++++++++++++---
 ffi/include/tvm/ffi/memory.h                  |  46 ++++++-----
 ffi/include/tvm/ffi/object.h                  |  68 ++++++++++++----
 ffi/include/tvm/ffi/optional.h                |   4 +-
 ffi/include/tvm/ffi/reflection/access_path.h  |  61 +++++++++++++++
 ffi/include/tvm/ffi/reflection/accessor.h     |  41 +++++++++-
 ffi/include/tvm/ffi/reflection/creator.h      |   8 ++
 ffi/include/tvm/ffi/reflection/registry.h     |  74 +++++++++++++++---
 ffi/include/tvm/ffi/string.h                  |  54 ++++++++++---
 ffi/include/tvm/ffi/type_traits.h             |  27 +++++++
 python/tvm/tir/tensor_intrin/riscv_cpu.py     |   5 +-
 src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc |   2 +-
 38 files changed, 983 insertions(+), 147 deletions(-)

diff --git a/ffi/docs/.gitignore b/ffi/docs/.gitignore
index 0b4a3621d9..d7ab85b91f 100644
--- a/ffi/docs/.gitignore
+++ b/ffi/docs/.gitignore
@@ -1,2 +1,2 @@
 _build
-**/generated/*.rst
+**/generated/*
diff --git a/ffi/docs/Makefile b/ffi/docs/Makefile
index ff28cb0cbc..51e4de21d3 100644
--- a/ffi/docs/Makefile
+++ b/ffi/docs/Makefile
@@ -27,14 +27,15 @@ help:
 
 .PHONY: help Makefile livehtml clean
 
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-       @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
 livehtml:
-       @sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+       @sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 
--ignore reference/cpp/generated
 
 clean:
        rm -rf $(BUILDDIR)
        rm -rf reference/python/generated
+       rm -rf reference/cpp/generated
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+       @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/ffi/docs/README.md b/ffi/docs/README.md
index cf96b6f6d4..39fff194df 100644
--- a/ffi/docs/README.md
+++ b/ffi/docs/README.md
@@ -33,3 +33,14 @@ Then build the doc
 ```bash
 make livehtml
 ```
+
+## Build with C++ Docs
+
+To build the C++ docs, first install Doxygen. Then set the environment
+variable `BUILD_CPP_DOCS=1` to turn on the C++ docs.
+
+```bash
+BUILD_CPP_DOCS=1 make livehtml
+```
+
+Building the C++ docs takes longer, so it is off by default.
diff --git a/ffi/docs/conf.py b/ffi/docs/conf.py
index b97ed78ef8..139254fd97 100644
--- a/ffi/docs/conf.py
+++ b/ffi/docs/conf.py
@@ -23,6 +23,9 @@ import tomli
 
 os.environ["TVM_FFI_BUILD_DOCS"] = "1"
 
+build_exhale = os.environ.get("BUILD_CPP_DOCS", "0") == "1"
+
+
 # -- General configuration ------------------------------------------------
 
 # Load version from pyproject.toml
@@ -38,6 +41,7 @@ release = __version__
 # -- Extensions and extension configurations --------------------------------
 
 extensions = [
+    "breathe",
     "myst_parser",
     "nbsphinx",
     "autodocsumm",
@@ -48,6 +52,7 @@ extensions = [
     "sphinx.ext.mathjax",
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
+    "sphinx.ext.ifconfig",
     "sphinx_copybutton",
     "sphinx_reredirects",
     "sphinx_tabs.tabs",
@@ -56,6 +61,40 @@ extensions = [
     "sphinxcontrib.mermaid",
 ]
 
+if build_exhale:
+    extensions.append("exhale")
+
+breathe_default_project = "tvm-ffi"
+
+breathe_projects = {"tvm-ffi": "./_build/doxygen/xml"}
+
+exhaleDoxygenStdin = """
+INPUT = ../include
+PREDEFINED  += TVM_FFI_DLL= TVM_FFI_INLINE= TVM_FFI_EXTRA_CXX_API= 
__cplusplus=201703
+
+EXCLUDE_SYMBOLS   += *details*  *TypeTraits* std \
+                         *use_default_type_traits_v* *is_optional_type_v* 
*operator* \
+
+EXCLUDE_PATTERNS   += *details.h
+ENABLE_PREPROCESSING   = YES
+MACRO_EXPANSION        = YES
+"""
+
+exhaleAfterTitleDescription = """
+This page contains the full API index for the C++ API.
+"""
+
+# Setup the exhale extension
+exhale_args = {
+    "containmentFolder": "reference/cpp/generated",
+    "rootFileName": "index.rst",
+    "doxygenStripFromPath": "../include",
+    "rootFileTitle": "Full API Index",
+    "createTreeView": True,
+    "exhaleExecutesDoxygen": True,
+    "exhaleDoxygenStdin": exhaleDoxygenStdin,
+    "afterTitleDescription": exhaleAfterTitleDescription,
+}
 nbsphinx_allow_errors = True
 nbsphinx_execute = "never"
 
@@ -69,6 +108,7 @@ myst_enable_extensions = [
     "colon_fence",
     "html_image",
     "linkify",
+    "attrs_block",
     "substitution",
 ]
 
diff --git a/ffi/docs/guides/cpp_guide.md b/ffi/docs/guides/cpp_guide.md
index fdbd7f7d7b..6b976dd635 100644
--- a/ffi/docs/guides/cpp_guide.md
+++ b/ffi/docs/guides/cpp_guide.md
@@ -14,6 +14,8 @@
 <!--- KIND, either express or implied.  See the License for the -->
 <!--- specific language governing permissions and limitations -->
 <!--- under the License. -->
+{#cpp-guide}
+
 # C++ Guide
 
 This guide introduces the tvm-ffi C++ API.
diff --git a/ffi/docs/index.rst b/ffi/docs/index.rst
index 0739f8c2ee..643ee41791 100644
--- a/ffi/docs/index.rst
+++ b/ffi/docs/index.rst
@@ -18,6 +18,10 @@
 Apache TVM FFI Documentation
 ============================
 
+Welcome to the documentation for TVM FFI. You can get started by reading the Get Started section,
+or by reading through the Guides and Concepts sections.
+
+
 .. toctree::
    :maxdepth: 1
    :caption: Get Started
@@ -40,8 +44,10 @@ Apache TVM FFI Documentation
 
    concepts/abi_overview.md
 
+
 .. toctree::
    :maxdepth: 1
    :caption: Reference
 
    reference/python/index.rst
+   reference/cpp/index.rst
diff --git a/ffi/docs/reference/cpp/index.rst b/ffi/docs/reference/cpp/index.rst
new file mode 100644
index 0000000000..ac9b1d73f9
--- /dev/null
+++ b/ffi/docs/reference/cpp/index.rst
@@ -0,0 +1,107 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+C++ API
+=======
+
+This page contains the API reference for the C++ API. The full API index below
+can be a bit dense, so we recommend the following tips first:
+
+- Please read the :ref:`C++ Guide<cpp-guide>` for a high-level overview of the C++ API.
+
+  - The C++ Guide and examples will likely be sufficient to get started with most use cases.
+
+- The :ref:`cpp-key-classes` section lists the most commonly used classes.
+- The :ref:`cpp-full-api-index` at the bottom of this page gives access to the full list of APIs.
+
+  - APIs are usually grouped by file. You can look at the file hierarchy in the
+    full API index and navigate to a specific file to find the APIs in that file.
+
+Header Organization
+-------------------
+
+The C++ APIs are organized into the following folders:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 30 70
+
+   * - Folder
+     - Description
+   * - ``tvm/ffi/``
+     - Core functionality such as Function, Any, Object, etc.
+   * - ``tvm/ffi/container/``
+     - Additional container types such as Array, Map, Shape, Tensor, Tuple, and Variant.
+   * - ``tvm/ffi/reflection/``
+     - Reflection support for function and type information registration.
+   * - ``tvm/ffi/extra/``
+     - Extra APIs built on top of the core, such as Module, JSON, and serialization.
+
+
+.. _cpp-key-classes:
+
+Key Classes
+-----------
+
+.. list-table::
+   :header-rows: 1
+   :widths: 30 70
+
+   * - Class
+     - Description
+   * - :cpp:class:`tvm::ffi::Function`
+     - Type-erased function that implements the ABI.
+   * - :cpp:class:`tvm::ffi::Any`
+     - Type-erased container for any supported value.
+   * - :cpp:class:`tvm::ffi::AnyView`
+     - Lightweight view of Any without ownership.
+   * - :cpp:class:`tvm::ffi::Object`
+     - Base class for all heap-allocated FFI objects.
+   * - :cpp:class:`tvm::ffi::ObjectRef`
+     - Reference class for objects.
+   * - :cpp:class:`tvm::ffi::Tensor`
+     - Multi-dimensional tensor with DLPack support.
+   * - :cpp:class:`tvm::ffi::Shape`
+     - Tensor shape container.
+   * - :cpp:class:`tvm::ffi::Module`
+     - Dynamic library module that can load exported functions.
+   * - :cpp:class:`tvm::ffi::String`
+     - String type for FFI.
+   * - :cpp:class:`tvm::ffi::Bytes`
+     - Byte array type.
+   * - :cpp:class:`tvm::ffi::Array`
+     - Dynamic array container.
+   * - :cpp:class:`tvm::ffi::Tuple`
+     - Heterogeneous tuple container.
+   * - :cpp:class:`tvm::ffi::Map`
+     - Key-value map container.
+   * - :cpp:class:`tvm::ffi::Optional`
+     - Optional value wrapper.
+   * - :cpp:class:`tvm::ffi::Variant`
+     - Type-safe union container.
+
+
+
+.. _cpp-full-api-index:
+
+Full API Index
+--------------
+
+.. toctree::
+   :maxdepth: 2
+
+   generated/index.rst
diff --git a/ffi/docs/requirements.txt b/ffi/docs/requirements.txt
index 0d09ef1815..74784b5153 100644
--- a/ffi/docs/requirements.txt
+++ b/ffi/docs/requirements.txt
@@ -1,4 +1,6 @@
 autodocsumm
+exhale
+breathe
 linkify-it-py
 matplotlib
 myst-parser
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index ed34328d1e..738adc4f86 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -52,7 +52,7 @@ class AnyView {
   friend class Any;
 
  public:
-  // NOTE: the following two functions uses styl style
+  // NOTE: the following functions use STL-style lowercase naming
   // since they are common functions appearing in FFI.
   /*!
    * \brief Reset any view to None
@@ -64,13 +64,13 @@ class AnyView {
     data_.v_int64 = 0;
   }
   /*!
-   * \brief Swap this array with another Object
-   * \param other The other Object
+   * \brief Swap this AnyView with another AnyView
+   * \param other The other AnyView
    */
   TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, 
other.data_); }
   /*! \return the internal type index */
   TVM_FFI_INLINE int32_t type_index() const noexcept { return 
data_.type_index; }
-  // default constructors
+  /*! \brief Default constructor */
   AnyView() {
     data_.type_index = TypeIndex::kTVMFFINone;
     data_.zero_padding = 0;
@@ -78,8 +78,11 @@ class AnyView {
   }
   ~AnyView() = default;
   // constructors from any view
+  /*! \brief Copy constructor */
   AnyView(const AnyView&) = default;
+  /*! \brief Copy assignment operator */
   AnyView& operator=(const AnyView&) = default;
+  /*! \brief Move constructor */
   AnyView(AnyView&& other) : data_(other.data_) {
     other.data_.type_index = TypeIndex::kTVMFFINone;
     other.data_.zero_padding = 0;
@@ -90,11 +93,20 @@ class AnyView {
     AnyView(std::move(other)).swap(*this);  // NOLINT(*)
     return *this;
   }
-  // constructor from general types
+  /*!
+   * \brief Constructor from a general type.
+   * \tparam T The type to convert from.
+   * \param other The value to convert from.
+   */
   template <typename T, typename = 
std::enable_if_t<TypeTraits<T>::convert_enabled>>
   AnyView(const T& other) {  // NOLINT(*)
     TypeTraits<T>::CopyToAnyView(other, &data_);
   }
+  /*!
+   * \brief Assign from a general type.
+   * \tparam T The type to convert from.
+   * \param other The value to convert from.
+   */
   template <typename T, typename = 
std::enable_if_t<TypeTraits<T>::convert_enabled>>
   TVM_FFI_INLINE AnyView& operator=(const T& other) {  // NOLINT(*)
     // copy-and-swap idiom
@@ -117,7 +129,7 @@ class AnyView {
       return std::optional<T>(std::nullopt);
     }
   }
-  /*
+  /*!
    * \brief Shortcut of as Object to cast to a const pointer when T is an 
Object.
    *
    * \tparam T The object type.
@@ -128,7 +140,7 @@ class AnyView {
     return this->as<const T*>().value_or(nullptr);
   }
 
-  /**
+  /*!
    * \brief Cast to a type T.
    *
    * \tparam T The type to cast to.
@@ -243,44 +255,71 @@ class Any {
     data_.v_int64 = 0;
   }
   /*!
-   * \brief Swap this array with another Object
-   * \param other The other Object
+   * \brief Swap this Any with another Any
+   * \param other The other Any
    */
   TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, 
other.data_); }
   /*! \return the internal type index */
   TVM_FFI_INLINE int32_t type_index() const noexcept { return 
data_.type_index; }
-  // default constructors
+  /*!
+   * \brief Default constructor
+   */
   Any() {
     data_.type_index = TypeIndex::kTVMFFINone;
     data_.zero_padding = 0;
     data_.v_int64 = 0;
   }
+  /*!
+   * \brief Destructor
+   */
   ~Any() { this->reset(); }
-  // constructors from Any
+  /*!
+   * \brief Constructor from another Any
+   * \param other The other Any
+   */
   Any(const Any& other) : data_(other.data_) {
     if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
       details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj);
     }
   }
+  /*!
+   * \brief Move constructor from another Any
+   * \param other The other Any
+   */
   Any(Any&& other) : data_(other.data_) {
     other.data_.type_index = TypeIndex::kTVMFFINone;
     other.data_.zero_padding = 0;
     other.data_.v_int64 = 0;
   }
+  /*!
+   * \brief Assign from another Any
+   * \param other The other Any
+   */
   TVM_FFI_INLINE Any& operator=(const Any& other) {
     // copy-and-swap idiom
     Any(other).swap(*this);  // NOLINT(*)
     return *this;
   }
+  /*!
+   * \brief Move assign from another Any
+   * \param other The other Any
+   */
   TVM_FFI_INLINE Any& operator=(Any&& other) {
     // copy-and-swap idiom
     Any(std::move(other)).swap(*this);  // NOLINT(*)
     return *this;
   }
-  // convert from/to AnyView
+  /*!
+   * \brief Constructor from another AnyView
+   * \param other The other AnyView
+   */
   Any(const AnyView& other) : data_(other.data_) {  // NOLINT(*)
     details::InplaceConvertAnyViewToAny(&data_);
   }
+  /*!
+   * \brief Assign from another AnyView
+   * \param other The other AnyView
+   */
   TVM_FFI_INLINE Any& operator=(const AnyView& other) {
     // copy-and-swap idiom
     Any(other).swap(*this);  // NOLINT(*)
@@ -288,11 +327,18 @@ class Any {
   }
   /*! \brief Any can be converted to AnyView in zero cost. */
   operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); }
-  // constructor from general types
+  /*!
+   * \brief Constructor from a general type
+   * \tparam T The value type of the other
+   */
   template <typename T, typename = 
std::enable_if_t<TypeTraits<T>::convert_enabled>>
   Any(T other) {  // NOLINT(*)
     TypeTraits<T>::MoveToAny(std::move(other), &data_);
   }
+  /*!
+   * \brief Assignment from a general type
+   * \tparam T The value type of the other
+   */
   template <typename T, typename = 
std::enable_if_t<TypeTraits<T>::convert_enabled>>
   TVM_FFI_INLINE Any& operator=(T other) {  // NOLINT(*)
     // copy-and-swap idiom
@@ -342,7 +388,7 @@ class Any {
     }
   }
 
-  /*
+  /*!
    * \brief Shortcut of as Object to cast to a const pointer when T is an 
Object.
    *
    * \tparam T The object type.
@@ -405,7 +451,7 @@ class Any {
       return TypeTraits<T>::TryCastFromAnyView(&data_);
     }
   }
-  /*
+  /*!
    * \brief Check if the two Any are same type and value in shallow comparison.
    * \param other The other Any
    * \return True if the two Any are same type and value, false otherwise.
@@ -415,7 +461,7 @@ class Any {
            data_.zero_padding == other.data_.zero_padding && data_.v_int64 == 
other.data_.v_int64;
   }
 
-  /*
+  /*!
    * \brief Check if any and ObjectRef are same type and value in shallow 
comparison.
    * \param other The other ObjectRef
    * \return True if the two Any are same type and value, false otherwise.
diff --git a/ffi/include/tvm/ffi/base_details.h 
b/ffi/include/tvm/ffi/base_details.h
index 7c96b091d7..80cd889ddb 100644
--- a/ffi/include/tvm/ffi/base_details.h
+++ b/ffi/include/tvm/ffi/base_details.h
@@ -19,7 +19,7 @@
 /*!
  * \file tvm/ffi/base_details.h
  * \brief Internal detail utils that can be used by files in tvm/ffi.
- * \note details header are for internal use only
+ * \note details headers are for internal use only
  *       and not to be directly used by user.
  */
 #ifndef TVM_FFI_BASE_DETAILS_H_
@@ -47,6 +47,7 @@
 #endif
 
 #endif
+/// \cond Doxygen_Suppress
 
 #if defined(_MSC_VER)
 #define TVM_FFI_INLINE [[msvc::forceinline]] inline
@@ -268,4 +269,5 @@ TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const 
TVMFFIAny* data) {
 }  // namespace details
 }  // namespace ffi
 }  // namespace tvm
+/// \endcond
 #endif  // TVM_FFI_BASE_DETAILS_H_
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 2a694fc4ad..5d67fcd221 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -202,7 +202,7 @@ typedef enum {
  * \brief C-based type of all FFI object header that allocates on heap.
  * \note TVMFFIObject and TVMFFIAny share the common type_index header
  */
-typedef struct TVMFFIObject {
+typedef struct {
   /*!
    * \brief type index of the object.
    * \note The type index of Object and Any are shared in FFI.
@@ -223,7 +223,7 @@ typedef struct TVMFFIObject {
      * \param flags The flags to indicate deletion behavior.
      * \sa TVMFFIObjectDeleterFlagBitMask
      */
-    void (*deleter)(struct TVMFFIObject* self, int flags);
+    void (*deleter)(void* self, int flags);
     /*!
      * \brief auxilary field to TVMFFIObject is always 8 bytes aligned.
      * \note This helps us to ensure cross platform compatibility.
@@ -238,7 +238,7 @@ typedef struct TVMFFIObject {
  * Any value can hold on stack values like int,
  * as well as reference counted pointers to object.
  */
-typedef struct TVMFFIAny {
+typedef struct {
   /*!
    * \brief type index of the object.
    * \note The type index of Object and Any are shared in FFI.
@@ -281,7 +281,9 @@ typedef struct TVMFFIAny {
  *       The FFI binding should be careful when treating this ABI.
  */
 typedef struct {
+  /*! \brief The data pointer. */
   const char* data;
+  /*! \brief The size of the data. */
   size_t size;
 } TVMFFIByteArray;
 
@@ -289,7 +291,9 @@ typedef struct {
  * \brief Shape cell used in shape object following header.
  */
 typedef struct {
+  /*! \brief The data pointer. */
   const int64_t* data;
+  /*! \brief The size of the data. */
   size_t size;
 } TVMFFIShapeCell;
 
@@ -442,7 +446,7 @@ TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const 
TVMFFIByteArray* name, TVMFFIObjec
 
 /*!
  * \brief Convert an AnyView to an owned Any.
- * \param any The AnyView to convert.
+ * \param any_view The AnyView to convert.
  * \param out The output Any, must be an empty object.
  * \return 0 on success, nonzero on failure.
  */
@@ -724,9 +728,9 @@ typedef struct {
    *
    * Possible values:
    *
-   *  - TVMFFITypeIndex::kTVMFFIObject for general objects
-   *    - The value is nullable when kTVMFFIObject is chosen
-   * - static object type kinds such as Map, Dict, String
+   * - TVMFFITypeIndex::kTVMFFIObject for general objects.
+   *   The value is nullable when kTVMFFIObject is chosen.
+   * - Static object type kinds such as Map, Dict, String
    * - POD type index, note it does not give information about storage size of 
the field.
    * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
    *   about the field.
@@ -793,7 +797,7 @@ typedef struct {
   TVMFFISEqHashKind structural_eq_hash_kind;
 } TVMFFITypeMetadata;
 
-/*
+/*!
  * \brief Column array that stores extra attributes about types
  *
  * The attributes stored in a column array that can be looked up by type index.
@@ -813,7 +817,11 @@ typedef struct {
 /*!
  * \brief Runtime type information for object type checking.
  */
+#ifdef __cplusplus
+struct TVMFFITypeInfo {
+#else
 typedef struct TVMFFITypeInfo {
+#endif
   /*!
    *\brief The runtime type index,
    * It can be allocated during runtime if the type is dynamic.
@@ -842,7 +850,11 @@ typedef struct TVMFFITypeInfo {
   const TVMFFIMethodInfo* methods;
   /*! \brief The extra information of the type. */
   const TVMFFITypeMetadata* metadata;
+#ifdef __cplusplus
+};
+#else
 } TVMFFITypeInfo;
+#endif
 
 /*!
  * \brief Register the function to runtime's global table.
@@ -860,7 +872,7 @@ TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const 
TVMFFIByteArray* name, TVMFFIObjec
  * This is the same as TVMFFIFunctionSetGlobal but with method info that can 
provide extra
  * metadata used in the runtime.
  * \param method_info The method info to be registered.
- * \param override Whether to allow overriding an already registered function.
+ * \param allow_override Whether to allow overriding an already registered function.
  * \return 0 on success, nonzero on failure.
  */
 TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* 
method_info,
@@ -923,19 +935,21 @@ TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const 
char* filename, int lin
 
 /*!
  * \brief Initialize the type info during runtime.
+ *
  * When the function is first called for a type,
  * it will register the type to the type table in the runtime.
  * If the static_tindex is non-negative, the function will
  * allocate a runtime type index.
  * Otherwise, we will populate the type table and return the static index.
+ *
  * \param type_key The type key.
+ * \param type_depth The type depth.
  * \param static_type_index Static type index if any, can be -1, which means 
this is a dynamic index
  * \param num_child_slots Number of slots reserved for its children.
  * \param child_slots_can_overflow Whether to allow child to overflow the 
slots.
  * \param parent_type_index Parent type index, pass in -1 if it is root.
- * \param result The output type index
  *
- * \return 0 if success, -1 if error occured
+ * \return The allocated type index.
  */
 TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key,
                                               int32_t static_type_index, 
int32_t type_depth,
@@ -974,7 +988,7 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle 
obj) {
 
 /*!
  * \brief Get the content of a small string in bytearray format.
- * \param obj The object handle.
+ * \param value Pointer to the Any value that holds the small string.
  * \return The content of the small string in bytearray format.
  */
 inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* 
value) {
diff --git a/ffi/include/tvm/ffi/container/array.h 
b/ffi/include/tvm/ffi/container/array.h
index 180c870ccb..077a55d6d1 100644
--- a/ffi/include/tvm/ffi/container/array.h
+++ b/ffi/include/tvm/ffi/container/array.h
@@ -21,7 +21,7 @@
  * \file tvm/ffi/container/array.h
  * \brief Array type.
  *
- * tvm::ffi::Array<Any> is an erased type that contains list of content
+ * tvm::ffi::Array<Any> is an erased type that contains a list of content
  */
 #ifndef TVM_FFI_CONTAINER_ARRAY_H_
 #define TVM_FFI_CONTAINER_ARRAY_H_
@@ -41,7 +41,7 @@
 namespace tvm {
 namespace ffi {
 
-/*! \brief array node content in array */
+/*! \brief Array node content in array */
 class ArrayObj : public Object, public details::InplaceArrayBase<ArrayObj, 
TVMFFIAny> {
  public:
   ~ArrayObj() {
@@ -106,7 +106,7 @@ class ArrayObj : public Object, public 
details::InplaceArrayBase<ArrayObj, TVMFF
   static ObjectPtr<ArrayObj> CopyFrom(int64_t cap, ArrayObj* from) {
     int64_t size = from->size_;
     if (size > cap) {
-      TVM_FFI_THROW(ValueError) << "not enough capacity";
+      TVM_FFI_THROW(ValueError) << "Not enough capacity";
     }
     ObjectPtr<ArrayObj> p = ArrayObj::Empty(cap);
     Any* write = p->MutableBegin();
@@ -127,7 +127,7 @@ class ArrayObj : public Object, public 
details::InplaceArrayBase<ArrayObj, TVMFF
   static ObjectPtr<ArrayObj> MoveFrom(int64_t cap, ArrayObj* from) {
     int64_t size = from->size_;
     if (size > cap) {
-      TVM_FFI_THROW(RuntimeError) << "not enough capacity";
+      TVM_FFI_THROW(RuntimeError) << "Not enough capacity";
     }
     ObjectPtr<ArrayObj> p = ArrayObj::Empty(cap);
     Any* write = p->MutableBegin();
@@ -155,10 +155,12 @@ class ArrayObj : public Object, public 
details::InplaceArrayBase<ArrayObj, TVMFF
     return p;
   }
 
+  /// \cond Doxygen_Suppress
   static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIArray;
   static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ArrayObj, Object);
+  /// \endcond
 
  private:
   /*! \return Size of initialized memory, used by InplaceArrayBase. */
@@ -172,6 +174,7 @@ class ArrayObj : public Object, public 
details::InplaceArrayBase<ArrayObj, TVMFF
 
   /*!
    * \brief Emplace a new element at the back of the array
+   * \param idx The index of the element.
    * \param args The arguments to construct the new element
    */
   template <typename... Args>
@@ -328,6 +331,11 @@ struct is_valid_iterator<Optional<T>, IterType> : 
is_valid_iterator<T, IterType>
 template <typename IterType>
 struct is_valid_iterator<Any, IterType> : std::true_type {};
 
+/*!
+ * \brief Check whether IterType is a valid iterator for T.
+ * \tparam T The type.
+ * \tparam IterType The type of iterator.
+ */
 template <typename T, typename IterType>
 inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, 
IterType>::value;
 
@@ -351,32 +359,69 @@ inline constexpr bool is_valid_iterator_v = 
is_valid_iterator<T, IterType>::valu
 template <typename T, typename = typename 
std::enable_if_t<details::storage_enabled_v<T>>>
 class Array : public ObjectRef {
  public:
+  /*! \brief The value type of the array */
   using value_type = T;
   // constructors
   /*!
    * \brief default constructor
    */
   Array() { data_ = ArrayObj::Empty(); }
+  /*!
+   * \brief Move constructor
+   * \param other The other array
+   */
   Array(Array<T>&& other) : ObjectRef(std::move(other.data_)) {}
+  /*!
+   * \brief Copy constructor
+   * \param other The other array
+   */
   Array(const Array<T>& other) : ObjectRef(other.data_) {}
+  /*!
+   * \brief Construct by moving from an array with a compatible value type
+   * \param other The other array
+   * \tparam U The value type of the other array
+   */
   template <typename U, typename = 
std::enable_if_t<details::type_contains_v<T, U>>>
   Array(Array<U>&& other) : ObjectRef(std::move(other.data_)) {}
+  /*!
+   * \brief Construct by copying from an array with a compatible value type
+   * \param other The other array
+   * \tparam U The value type of the other array
+   */
   template <typename U, typename = 
std::enable_if_t<details::type_contains_v<T, U>>>
   Array(const Array<U>& other) : ObjectRef(other.data_) {}
 
+  /*!
+   * \brief Move assignment from another array
+   * \param other The other array
+   */
   TVM_FFI_INLINE Array<T>& operator=(Array<T>&& other) {
     data_ = std::move(other.data_);
     return *this;
   }
+  /*!
+   * \brief Assignment from another array
+   * \param other The other array
+   */
   TVM_FFI_INLINE Array<T>& operator=(const Array<T>& other) {
     data_ = other.data_;
     return *this;
   }
+  /*!
+   * \brief Move assignment from another array
+   * \param other The other array
+   * \tparam U The value type of the other array
+   */
   template <typename U, typename = 
std::enable_if_t<details::type_contains_v<T, U>>>
   TVM_FFI_INLINE Array<T>& operator=(Array<U>&& other) {
     data_ = std::move(other.data_);
     return *this;
   }
+  /*!
+   * \brief Assignment from another array
+   * \param other The other array
+   * \tparam U The value type of the other array
+   */
   template <typename U, typename = 
std::enable_if_t<details::type_contains_v<T, U>>>
   TVM_FFI_INLINE Array<T>& operator=(const Array<U>& other) {
     data_ = other.data_;
@@ -384,7 +429,7 @@ class Array : public ObjectRef {
   }
 
   /*!
-   * \brief constructor from pointer
+   * \brief Constructor from pointer
    * \param n the container pointer
    */
   explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
@@ -427,12 +472,21 @@ class Array : public ObjectRef {
 
  public:
   // iterators
+  /// \cond Doxygen_Suppress
   struct ValueConverter {
     using ResultType = T;
+    /*!
+     * \brief Convert any to T
+     * \param n The any value to convert
+     * \return The converted value
+     */
     static T convert(const Any& n) { return 
details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); }
   };
+  /// \endcond
 
+  /*! \brief The iterator type of the array */
   using iterator = details::IterAdapter<ValueConverter, const Any*>;
+  /*! \brief The reverse iterator type of the array */
   using reverse_iterator = details::ReverseIterAdapter<ValueConverter, const 
Any*>;
 
   /*! \return begin iterator */
@@ -515,6 +569,10 @@ class Array : public ObjectRef {
     p->EmplaceInit(p->size_++, item);
   }
 
+  /*!
+   * \brief Emplace a new element at the back of the array
+   * \param args The arguments to construct the new element
+   */
   template <typename... Args>
   void emplace_back(Args&&... args) {
     ArrayObj* p = CopyOnWrite(1);
@@ -660,7 +718,7 @@ class Array : public ObjectRef {
       p->clear();
     }
   }
-
+  /// \cond Doxygen_Suppress
   template <typename... Args>
   static size_t CalcCapacityImpl() {
     return 0;
@@ -690,6 +748,7 @@ class Array : public ObjectRef {
     dest.push_back(value);
     AgregateImpl(dest, args...);
   }
+  /// \endcond
 
  public:
   // Array's own methods
@@ -986,7 +1045,10 @@ inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) 
{
   return std::move(lhs);
 }
 
-// Specialize make_object<ArrayObj> to make sure it is correct.
+/*!
+ * \brief Specialization of make_object<ArrayObj> that returns ArrayObj::Empty().
+ * \return The empty array object.
+ */
 template <>
 inline ObjectPtr<ArrayObj> make_object() {
   return ArrayObj::Empty();
@@ -1079,8 +1141,6 @@ inline constexpr bool type_contains_v<Array<T>, Array<U>> 
= type_contains_v<T, U
 
 }  // namespace ffi
 
-// Expose to the tvm namespace
-// Rationale: convinience and no ambiguity
 using ffi::Array;
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_ARRAY_H_
diff --git a/ffi/include/tvm/ffi/container/map.h 
b/ffi/include/tvm/ffi/container/map.h
index b1ca4f805e..8103c447c1 100644
--- a/ffi/include/tvm/ffi/container/map.h
+++ b/ffi/include/tvm/ffi/container/map.h
@@ -40,12 +40,14 @@
 namespace tvm {
 namespace ffi {
 
+/// \cond Doxygen_Suppress
 #if TVM_FFI_DEBUG_WITH_ABI_CHANGE
 #define TVM_FFI_MAP_FAIL_IF_CHANGED() \
   TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent 
modification of the Map";
 #else
 #define TVM_FFI_MAP_FAIL_IF_CHANGED()
 #endif  // TVM_FFI_DEBUG_WITH_ABI_CHANGE
+/// \endcond
 
 /*! \brief Shared content of all specializations of hash map */
 class MapObj : public Object {
@@ -56,24 +58,28 @@ class MapObj : public Object {
   using mapped_type = Any;
   /*! \brief Type of value stored in the hash map */
   using KVType = std::pair<Any, Any>;
+  /// \cond Doxygen_Suppress
   /*! \brief Type of raw storage of the key-value pair in the hash map */
   struct KVRawStorageType {
     TVMFFIAny first;
     TVMFFIAny second;
   };
+  /// \endcond
   /*! \brief Iterator class */
   class iterator;
 
   static_assert(std::is_standard_layout<KVType>::value, "KVType is not 
standard layout");
   static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect");
 
+  /// \cond Doxygen_Suppress
   static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIMap;
   static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(MapObj, Object);
+  /// \endcond
 
   /*!
-   * \brief Number of elements in the SmallMapObj
+   * \brief Number of elements in the MapObj
    * \return The result
    */
   size_t size() const { return size_; }
@@ -116,6 +122,7 @@ class MapObj : public Object {
    */
   void erase(const key_type& key) { erase(find(key)); }
 
+  /// \cond Doxygen_Suppress
   class iterator {
    public:
     using iterator_category = std::forward_iterator_tag;
@@ -180,6 +187,7 @@ class MapObj : public Object {
     friend class DenseMapObj;
     friend class SmallMapObj;
   };
+  /// \endcond
   /*!
    * \brief Create an empty container
    * \return The object created
@@ -1206,6 +1214,7 @@ class DenseMapObj : public MapObj {
   }
 };
 
+/// \cond Doxygen_Suppress
 #define TVM_FFI_DISPATCH_MAP(base, var, body) \
   {                                           \
     using TSmall = SmallMapObj*;              \
@@ -1280,6 +1289,7 @@ inline MapObj::iterator MapObj::find(const 
MapObj::key_type& key) const {
 inline void MapObj::erase(const MapObj::iterator& position) {
   TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); });
 }
+/// \endcond
 
 #undef TVM_FFI_DISPATCH_MAP
 #undef TVM_FFI_DISPATCH_MAP_CONST
@@ -1365,8 +1375,11 @@ template <typename K, typename V,
                                                details::storage_enabled_v<V>>>
 class Map : public ObjectRef {
  public:
+  /*! \brief The key type of the map */
   using key_type = K;
+  /*! \brief The mapped type of the map */
   using mapped_type = V;
+  /*! \brief The iterator type of the map */
   class iterator;
   /*!
    * \brief default constructor
@@ -1383,24 +1396,52 @@ class Map : public ObjectRef {
    */
   Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
 
+  /*!
+   * \brief Move constructor
+   * \param other The other map
+   * \tparam KU The key type of the other map
+   * \tparam VU The mapped type of the other map
+   */
   template <typename KU, typename VU,
             typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                         details::type_contains_v<V, VU>>>
   Map(Map<KU, VU>&& other) : ObjectRef(std::move(other.data_)) {}
 
+  /*!
+   * \brief Copy constructor
+   * \param other The other map
+   * \tparam KU The key type of the other map
+   * \tparam VU The mapped type of the other map
+   */
   template <typename KU, typename VU,
             typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                         details::type_contains_v<V, VU>>>
   Map(const Map<KU, VU>& other) : ObjectRef(other.data_) {}
+
+  /*!
+   * \brief Move assignment
+   * \param other The other map
+   */
   Map<K, V>& operator=(Map<K, V>&& other) {
     data_ = std::move(other.data_);
     return *this;
   }
+
+  /*!
+   * \brief Copy assignment
+   * \param other The other map
+   */
   Map<K, V>& operator=(const Map<K, V>& other) {
     data_ = other.data_;
     return *this;
   }
 
+  /*!
+   * \brief Move assignment
+   * \param other The other map
+   * \tparam KU The key type of the other map
+   * \tparam VU The mapped type of the other map
+   */
   template <typename KU, typename VU,
             typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                         details::type_contains_v<V, VU>>>
@@ -1409,6 +1450,12 @@ class Map : public ObjectRef {
     return *this;
   }
 
+  /*!
+   * \brief Copy assignment
+   * \param other The other map
+   * \tparam KU The key type of the other map
+   * \tparam VU The mapped type of the other map
+   */
   template <typename KU, typename VU,
             typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                         details::type_contains_v<V, VU>>>
@@ -1502,6 +1549,11 @@ class Map : public ObjectRef {
     }
     return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
   }
+
+  /*!
+   * \brief Erase the entry associated with the key
+   * \param key The key
+   */
   void erase(const K& key) { CopyOnWrite()->erase(key); }
 
   /*!
@@ -1523,6 +1575,7 @@ class Map : public ObjectRef {
   /*! \brief specify container node */
   using ContainerType = MapObj;
 
+  /// \cond Doxygen_Suppress
   /*! \brief Iterator of the hash map */
   class iterator {
    public:
@@ -1579,6 +1632,7 @@ class Map : public ObjectRef {
 
     MapObj::iterator itr;
   };
+  /// \endcond
 
  private:
   /*! \brief Return data_ as type of pointer of MapObj */
@@ -1702,8 +1756,6 @@ inline constexpr bool type_contains_v<Map<K, V>, Map<KU, 
VU>> =
 
 }  // namespace ffi
 
-// Expose to the tvm namespace
-// Rationale: convinience and no ambiguity
 using ffi::Map;
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_MAP_H_
diff --git a/ffi/include/tvm/ffi/container/shape.h 
b/ffi/include/tvm/ffi/container/shape.h
index 28f4961c99..39c3ec2739 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file tvm/ffi/shape.h
+ * \file tvm/ffi/container/shape.h
  * \brief Container to store shape of an Tensor.
  */
 #ifndef TVM_FFI_CONTAINER_SHAPE_H_
@@ -39,6 +39,7 @@ namespace ffi {
 /*! \brief An object representing a shape tuple. */
 class ShapeObj : public Object, public TVMFFIShapeCell {
  public:
+  /*! \brief The type of shape index element. */
   using index_type = int64_t;
 
   /*! \brief Get "numel", meaning the number of elements of an array if the 
array has this shape */
@@ -50,9 +51,11 @@ class ShapeObj : public Object, public TVMFFIShapeCell {
     return product;
   }
 
+  /// \cond Doxygen_Suppress
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIShape;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeObj, Object);
+  /// \endcond
 };
 
 namespace details {
@@ -198,7 +201,9 @@ class Shape : public ObjectRef {
   /*! \return The product of the shape tuple */
   int64_t Product() const { return get()->Product(); }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj);
+  /// \endcond
 };
 
 inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
diff --git a/ffi/include/tvm/ffi/container/tensor.h 
b/ffi/include/tvm/ffi/container/tensor.h
index 8a8134d860..b5be116b49 100644
--- a/ffi/include/tvm/ffi/container/tensor.h
+++ b/ffi/include/tvm/ffi/container/tensor.h
@@ -19,8 +19,8 @@
  */
 
 /*!
- * \file tvm/ffi/tensor.h
- * \brief Container to store an Tensor.
+ * \file tvm/ffi/container/tensor.h
+ * \brief Container to store a Tensor.
  */
 #ifndef TVM_FFI_CONTAINER_TENSOR_H_
 #define TVM_FFI_CONTAINER_TENSOR_H_
@@ -80,11 +80,11 @@ inline bool IsAligned(const DLTensor& arr, size_t 
alignment) {
 }
 
 /*!
- * \brief return the total number bytes needs to store packed data
+ * \brief return the total number of bytes needed to store packed data
  *
  * \param numel the number of elements in the array
  * \param dtype the data type of the array
- * \return the total number bytes needs to store packed data
+ * \return the total number of bytes needed to store packed data
  */
 inline size_t GetDataSize(int64_t numel, DLDataType dtype) {
   // compatible handling sub-byte uint1(bool), which usually stored as uint8_t
@@ -97,10 +97,10 @@ inline size_t GetDataSize(int64_t numel, DLDataType dtype) {
 }
 
 /*!
- * \brief return the size of data the DLTensor hold, in term of number of bytes
+ * \brief return the size of data the DLTensor holds, in terms of number of 
bytes
  *
  *  \param arr the input DLTensor
- *  \return number of  bytes of data in the DLTensor.
+ *  \return number of bytes of data in the DLTensor.
  */
 inline size_t GetDataSize(const DLTensor& arr) {
   size_t size = 1;
@@ -110,15 +110,17 @@ inline size_t GetDataSize(const DLTensor& arr) {
   return GetDataSize(size, arr.dtype);
 }
 
-/*! \brief An object representing an Tensor. */
+/*! \brief An object representing a Tensor. */
 class TensorObj : public Object, public DLTensor {
  public:
+  /// \cond Doxygen_Suppress
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFITensor;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TensorObj, Object);
+  /// \endcond
 
   /*!
-   * \brief Move Tensor to a DLPack managed tensor.
+   * \brief Move a Tensor to a DLPack managed tensor.
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensor* ToDLPack() const {
@@ -132,7 +134,7 @@ class TensorObj : public Object, public DLTensor {
   }
 
   /*!
-   * \brief Move  Tensor to a DLPack managed tensor.
+   * \brief Move a Tensor to a DLPack managed tensor.
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const {
@@ -149,16 +151,25 @@ class TensorObj : public Object, public DLTensor {
   }
 
  protected:
-  // backs up the shape/strides
+  /*! \brief Internal data to back returning shape. */
   Optional<Shape> shape_data_;
+  /*! \brief Internal data to back returning strides. */
   Optional<Shape> strides_data_;
 
+  /*!
+   * \brief Deleter for DLManagedTensor.
+   * \param tensor The DLManagedTensor to be deleted.
+   */
   static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
     TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
     details::ObjectUnsafe::DecRefObjectHandle(obj);
     delete tensor;
   }
 
+  /*!
+   * \brief Deleter for DLManagedTensorVersioned.
+   * \param tensor The DLManagedTensorVersioned to be deleted.
+   */
   static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* 
tensor) {
     TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
     details::ObjectUnsafe::DecRefObjectHandle(obj);
@@ -166,6 +177,8 @@ class TensorObj : public Object, public DLTensor {
   }
 
+  /// \cond Doxygen_Suppress
   friend class Tensor;
+  /// \endcond
 };
 
 namespace details {
@@ -272,6 +284,7 @@ class Tensor : public ObjectRef {
    * \param shape The shape of the Tensor.
    * \param dtype The data type of the Tensor.
    * \param device The device of the Tensor.
+   * \param extra_args Extra arguments to be forwarded to TNDAlloc.
    * \return The created Tensor.
    * \tparam TNDAlloc The type of the NDAllocator, impelments Alloc and Free.
    * \tparam ExtraArgs Extra arguments to be passed to Alloc.
@@ -337,7 +350,9 @@ class Tensor : public ObjectRef {
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const { return 
get_mutable()->ToDLPackVersioned(); }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
+  /// \endcond
 
  protected:
   /*!
diff --git a/ffi/include/tvm/ffi/container/tuple.h 
b/ffi/include/tvm/ffi/container/tuple.h
index be7e63fd94..0cb80b963e 100644
--- a/ffi/include/tvm/ffi/container/tuple.h
+++ b/ffi/include/tvm/ffi/container/tuple.h
@@ -45,33 +45,69 @@ class Tuple : public ObjectRef {
  public:
   static_assert(details::all_storage_enabled_v<Types...>,
                 "All types used in Tuple<...> must be compatible with Any");
-
+  /*! \brief Default constructor */
   Tuple() : ObjectRef(MakeDefaultTupleNode()) {}
+  /*! \brief Copy constructor */
   Tuple(const Tuple<Types...>& other) : ObjectRef(other) {}
+  /*! \brief Move constructor */
   Tuple(Tuple<Types...>&& other) : ObjectRef(std::move(other)) {}
+  /*!
+   * \brief Copy constructor from a tuple with compatible element types
+   * \param other The other tuple
+   * \tparam UTypes The types of the other tuple
+   * \note Only enabled when details::type_contains_v<Types, UTypes> holds pairwise.
+   */
   template <typename... UTypes,
             typename = std::enable_if_t<(details::type_contains_v<Types, 
UTypes> && ...), int>>
   Tuple(const Tuple<UTypes...>& other) : ObjectRef(other) {}
+
+  /*!
+   * \brief Move constructor from a tuple with compatible element types
+   * \param other The other tuple
+   * \tparam UTypes The types of the other tuple
+   * \note Only enabled when details::type_contains_v<Types, UTypes> holds pairwise.
+   */
   template <typename... UTypes,
             typename = std::enable_if_t<(details::type_contains_v<Types, 
UTypes> && ...), int>>
   Tuple(Tuple<UTypes...>&& other) : ObjectRef(std::move(other)) {}
 
+  /*!
+   * \brief Constructor from element values
+   * \param args The values of the tuple elements, one per element type
+   * \tparam UTypes The types of the arguments
+   */
   template <typename... UTypes, typename = std::enable_if_t<
                                     sizeof...(Types) == sizeof...(UTypes) &&
                                     !(sizeof...(Types) == 1 &&
                                       (std::is_same_v<std::decay_t<UTypes>, 
Tuple<Types>> && ...))>>
   explicit Tuple(UTypes&&... args) : 
ObjectRef(MakeTupleNode(std::forward<UTypes>(args)...)) {}
 
+  /*!
+   * \brief Copy assignment from another tuple
+   * \param other The other tuple
+   * \return Reference to this tuple
+   */
   TVM_FFI_INLINE Tuple& operator=(const Tuple<Types...>& other) {
     data_ = other.data_;
     return *this;
   }
 
+  /*!
+   * \brief Move assignment from another tuple
+   * \param other The other tuple
+   * \return Reference to this tuple
+   */
   TVM_FFI_INLINE Tuple& operator=(Tuple<Types...>&& other) {
     data_ = std::move(other.data_);
     return *this;
   }
 
+  /*!
+   * \brief Copy assignment from a tuple with compatible element types
+   * \param other The other tuple
+   * \tparam UTypes The types of the other tuple
+   * \return Reference to this tuple
+   */
   template <typename... UTypes,
             typename = std::enable_if_t<(details::type_contains_v<Types, 
UTypes> && ...)>>
   TVM_FFI_INLINE Tuple& operator=(const Tuple<UTypes...>& other) {
@@ -79,6 +115,12 @@ class Tuple : public ObjectRef {
     return *this;
   }
 
+  /*!
+   * \brief Move assignment from a tuple with compatible element types
+   * \param other The other tuple
+   * \tparam UTypes The types of the other tuple
+   * \return Reference to this tuple
+   */
   template <typename... UTypes,
             typename = std::enable_if_t<(details::type_contains_v<Types, 
UTypes> && ...)>>
   TVM_FFI_INLINE Tuple& operator=(Tuple<UTypes...>&& other) {
@@ -86,7 +128,12 @@ class Tuple : public ObjectRef {
     return *this;
   }
 
-  explicit Tuple(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief Constructor from an ObjectPtr
+   * \param ptr The object pointer to wrap; it is passed to ObjectRef
+   *        without any type check in this constructor
+   */
+  explicit Tuple(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
 
   /*!
    * \brief Get I-th element of the tuple
diff --git a/ffi/include/tvm/ffi/container/variant.h 
b/ffi/include/tvm/ffi/container/variant.h
index ee1f8316d8..5bea42cb05 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -102,6 +102,7 @@ class VariantBase<true> : public ObjectRef {
 template <typename... V>
 class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
  public:
+  /// \cond Doxygen_Suppress
   using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
   static_assert(details::all_storage_enabled_v<V...>,
                 "All types used in Variant<...> must be compatible with Any");
@@ -113,34 +114,63 @@ class Variant : public 
details::VariantBase<details::all_object_ref_v<V...>> {
   /* \brief Helper utility for SFINAE if the type is part of the variant */
   template <typename T>
   using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;
-
+  /// \endcond
+  /*!
+   * \brief Constructor from another variant
+   * \param other The other variant
+   */
   Variant(const Variant<V...>& other) : TParent(other.data_) {}
+  /*!
+   * \brief Constructor from another variant
+   * \param other The other variant
+   */
   Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}
 
+  /*!
+   * \brief Assignment from another variant
+   * \param other The other variant
+   */
   TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
     this->SetData(other.data_);
     return *this;
   }
 
+  /*!
+   * \brief Assignment from another variant
+   * \param other The other variant
+   */
   TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
     this->SetData(std::move(other.data_));
     return *this;
   }
 
+  /*!
+   * \brief Constructor from a value of one of the variant's alternative types
+   * \param other The value to store; T must be part of the variant
+   */
   template <typename T, typename = enable_if_variant_contains_t<T>>
   Variant(T other) : TParent(std::move(other)) {}  // NOLINT(*)
 
+  /*!
+   * \brief Assignment from a value of one of the variant's alternative types
+   * \param other The value to assign; T must be part of the variant
+   */
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE Variant& operator=(T other) {
     return operator=(Variant(std::move(other)));
   }
 
+  /*!
+   * \brief Try to cast to a type T, return std::nullopt if the cast is not 
possible.
+   * \return The casted value, or std::nullopt if the cast is not possible.
+   * \tparam T The type to cast to.
+   */
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE std::optional<T> as() const {
     return this->TParent::ToAnyView().template as<T>();
   }
 
-  /*
+  /*!
    * \brief Shortcut of as Object to cast to a const pointer when T is an 
Object.
    *
    * \tparam T The object type.
@@ -151,16 +181,30 @@ class Variant : public 
details::VariantBase<details::all_object_ref_v<V...>> {
     return this->TParent::ToAnyView().template as<const 
T*>().value_or(nullptr);
   }
 
+  /*!
+   * \brief Get the value of the variant in type T, throws an exception if 
cast fails.
+   * \return The value of the variant
+   * \tparam T The type to get.
+   */
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() const& {
     return this->TParent::ToAnyView().template cast<T>();
   }
 
+  /*!
+   * \brief Get the value of the variant in type T, throws an exception if 
cast fails.
+   * \return The value of the variant
+   * \tparam T The type to get.
+   */
   template <typename T, typename = enable_if_variant_contains_t<T>>
   TVM_FFI_INLINE T get() && {
     return std::move(*this).TParent::MoveToAny().template cast<T>();
   }
 
+  /*!
+   * \brief Get the type key of the variant
+   * \return The type key of the variant
+   */
   TVM_FFI_INLINE std::string GetTypeKey() const { return 
this->TParent::ToAnyView().GetTypeKey(); }
 
  private:
@@ -255,8 +299,6 @@ inline constexpr bool type_contains_v<Variant<V...>, T> = 
(type_contains_v<V, T>
 }  // namespace details
 }  // namespace ffi
 
-// Expose to the tvm namespace
-// Rationale: convinience and no ambiguity
 using ffi::Variant;
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_VARIANT_H_
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index c153d71cb7..8da30dc5d6 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -39,7 +39,7 @@ namespace ffi {
  *
  * This class is always consistent with the DLPack.
  *
- * TOTO(tvm-team): update to latest DLPack types.
+ * TODO(tvm-team): update to latest DLPack types.
  */
 enum DLExtDataTypeCode { kDLExtCustomBegin = 129 };
 
@@ -113,6 +113,11 @@ inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode 
type_code) {  // NOLINT(*
 }
 }  // namespace details
 
+/*!
+ * \brief Convert a string to a DLDataType.
+ * \param str The string to convert.
+ * \return The DLDataType.
+ */
 inline DLDataType StringToDLDataType(const String& str) {
   DLDataType out;
   TVMFFIByteArray data{str.data(), str.size()};
@@ -120,6 +125,11 @@ inline DLDataType StringToDLDataType(const String& str) {
   return out;
 }
 
+/*!
+ * \brief Convert a DLDataType to a string.
+ * \param dtype The DLDataType to convert.
+ * \return The string.
+ */
 inline String DLDataTypeToString(DLDataType dtype) {
   TVMFFIAny out;
   TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h
index 97311b988c..78dfe5ed5a 100644
--- a/ffi/include/tvm/ffi/error.h
+++ b/ffi/include/tvm/ffi/error.h
@@ -64,7 +64,7 @@ namespace ffi {
  *  This error can be thrown by EnvCheckSignals to indicate
  *  that there is an error set in the frontend environment(e.g.
  *  python interpreter). The TVM FFI should catch this error
- *  and return a proper code tell the frontend caller about
+ *  and return a proper code to tell the frontend caller about
  *  this fact.
  *
  * \code
@@ -85,10 +85,11 @@ struct EnvErrorAlreadySet : public std::exception {};
  */
 class ErrorObj : public Object, public TVMFFIErrorCell {
  public:
+  /// \cond Doxygen_Suppress
   static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError;
   static constexpr const char* _type_key = "ffi.Error";
-
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object);
+  /// \endcond
 };
 
 namespace details {
@@ -125,33 +126,65 @@ class ErrorObjFromStd : public ErrorObj {
  */
 class Error : public ObjectRef, public std::exception {
  public:
+  /*!
+   * \brief Constructor
+   * \param kind The kind of the error.
+   * \param message The message of the error.
+   * \param traceback The traceback of the error.
+   */
   Error(std::string kind, std::string message, std::string traceback) {
     data_ = make_object<details::ErrorObjFromStd>(kind, message, traceback);
   }
 
+  /*!
+   * \brief Constructor
+   * \param kind The kind of the error.
+   * \param message The message of the error.
+   * \param traceback The traceback of the error.
+   */
   Error(std::string kind, std::string message, const TVMFFIByteArray* 
traceback)
       : Error(kind, message, std::string(traceback->data, traceback->size)) {}
 
+  /*!
+   * \brief Get the kind of the error object.
+   * \return The kind of the error object.
+   */
   std::string kind() const {
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
     return std::string(obj->kind.data, obj->kind.size);
   }
 
+  /*!
+   * \brief Get the message of the error object.
+   * \return The message of the error object.
+   */
   std::string message() const {
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
     return std::string(obj->message.data, obj->message.size);
   }
 
+  /*!
+   * \brief Get the traceback of the error object.
+   * \return The traceback of the error object.
+   */
   std::string traceback() const {
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
     return std::string(obj->traceback.data, obj->traceback.size);
   }
 
+  /*!
+   * \brief Update the traceback of the error object.
+   * \param traceback_str The traceback to update.
+   */
   void UpdateTraceback(const TVMFFIByteArray* traceback_str) {
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
     obj->update_traceback(obj, traceback_str);
   }
 
+  /*!
+   * \brief Get the error message
+   * \return The error message
+   */
   const char* what() const noexcept(true) override {
     thread_local std::string what_data;
     ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
@@ -162,7 +195,9 @@ class Error : public ObjectRef, public std::exception {
     return what_data.c_str();
   }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj);
+  /// \endcond
 };
 
 namespace details {
diff --git a/ffi/include/tvm/ffi/extra/base64.h 
b/ffi/include/tvm/ffi/extra/base64.h
index 136fec2e7f..da763cfe3a 100644
--- a/ffi/include/tvm/ffi/extra/base64.h
+++ b/ffi/include/tvm/ffi/extra/base64.h
@@ -80,7 +80,7 @@ inline String Base64Encode(const Bytes& data) {
 
 /*!
  * \brief Decode a base64 string into a byte array
- * \param data The base64 encoded string to decode
+ * \param bytes The base64 encoded bytes to decode
  * \return The decoded byte array
  */
 inline Bytes Base64Decode(TVMFFIByteArray bytes) {
diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h 
b/ffi/include/tvm/ffi/extra/c_env_api.h
index 17cb3af6d0..6f8e44bdfb 100644
--- a/ffi/include/tvm/ffi/extra/c_env_api.h
+++ b/ffi/include/tvm/ffi/extra/c_env_api.h
@@ -34,6 +34,9 @@ extern "C" {
 // Focusing on minimalistic thread-local context recording stream being used.
 // We explicitly not handle allocation/de-allocation of stream here.
 // ----------------------------------------------------------------------------
+/*!
+ * \brief The type of the stream handle.
+ */
 typedef void* TVMFFIStreamHandle;
 
 /*!
@@ -91,7 +94,7 @@ TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char* name, void* 
symbol);
 TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, 
const char* func_name,
                                               TVMFFIObjectHandle* out);
 
-/*
+/*!
  * \brief Register a symbol value that will be initialized when a library with 
the symbol is loaded.
  *
  * This function can be used to make context functions to be available in the 
library
diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h
index 409f7aa525..24ab2f0d89 100644
--- a/ffi/include/tvm/ffi/extra/json.h
+++ b/ffi/include/tvm/ffi/extra/json.h
@@ -54,7 +54,7 @@ using Array = ffi::Array<Any>;
  * \brief Parse a JSON string into an Any value.
  *
  * Besides the standard JSON syntax, this function also supports:
- * - Infinity/NaN as javascript syntax
+ * - Infinity/NaN as JavaScript syntax
  * - int64 integer value
  *
  * If error_msg is not nullptr, the error message will be written to it
diff --git a/ffi/include/tvm/ffi/extra/module.h 
b/ffi/include/tvm/ffi/extra/module.h
index 1af2c2b6b2..89e0c287a3 100644
--- a/ffi/include/tvm/ffi/extra/module.h
+++ b/ffi/include/tvm/ffi/extra/module.h
@@ -17,7 +17,7 @@
  * under the License.
  */
 /*!
- * \file tvm/ffi/module.h
+ * \file tvm/ffi/extra/module.h
  * \brief A managed dynamic module in the TVM FFI.
  */
 #ifndef TVM_FFI_EXTRA_MODULE_H_
@@ -130,6 +130,7 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {
   /*!
    * \brief Get the function metadata of the function if available.
    * \param name The name of the function.
+   * \param query_imports Whether to query imported modules.
    * \return The function metadata of the function in json format.
    */
   Optional<String> GetFunctionMetadata(const String& name, bool query_imports);
@@ -142,10 +143,12 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {
 
   struct InternalUnsafe;
 
+  /// \cond Doxygen_Suppress
   static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIModule;
   static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleObj, Object);
+  /// \endcond
 
  protected:
   friend struct InternalUnsafe;
@@ -203,12 +206,11 @@ class Module : public ObjectRef {
   /*!
    * \brief Load a module from file.
    * \param file_name The name of the host function module.
-   * \param format The format of the file.
    * \note This function won't load the import relationship.
    *  Re-create import relationship by calling Import.
    */
   TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name);
-  /*
+  /*!
    * \brief Query context symbols that is registered via TVMEnvRegisterSymbols.
    * \param callback The callback to be called with the symbol name and 
address.
    * \note This helper can be used to implement custom Module that needs to 
access context symbols.
@@ -216,7 +218,9 @@ class Module : public ObjectRef {
   TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols(
       const ffi::TypedFunction<void(String, void*)>& callback);
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Module, ObjectRef, 
ModuleObj);
+  /// \endcond
 };
 
 /*
diff --git a/ffi/include/tvm/ffi/extra/serialization.h 
b/ffi/include/tvm/ffi/extra/serialization.h
index c08ad81cc3..b5aa2891ac 100644
--- a/ffi/include/tvm/ffi/extra/serialization.h
+++ b/ffi/include/tvm/ffi/extra/serialization.h
@@ -34,7 +34,7 @@ namespace ffi {
  *
  * The JSON graph structure is stored as follows:
  *
- * ```json
+ * ```
  * {
  *   "root_index": <int>,        // Index of root node in nodes array
  *   "nodes": [<node>, ...],     // Array of serialized nodes
diff --git a/ffi/include/tvm/ffi/extra/structural_equal.h 
b/ffi/include/tvm/ffi/extra/structural_equal.h
index 8eb5da7f67..ec960a85e6 100644
--- a/ffi/include/tvm/ffi/extra/structural_equal.h
+++ b/ffi/include/tvm/ffi/extra/structural_equal.h
@@ -30,7 +30,7 @@
 
 namespace tvm {
 namespace ffi {
-/*
+/*!
  * \brief Structural equality comparators
  */
 class StructuralEqual {
diff --git a/ffi/include/tvm/ffi/extra/structural_hash.h 
b/ffi/include/tvm/ffi/extra/structural_hash.h
index 1d7ba2613e..bfe023c382 100644
--- a/ffi/include/tvm/ffi/extra/structural_hash.h
+++ b/ffi/include/tvm/ffi/extra/structural_hash.h
@@ -29,7 +29,7 @@
 namespace tvm {
 namespace ffi {
 
-/*
+/*!
  * \brief Structural hash
  */
 class StructuralHash {
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index f84978800e..884e46fa44 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -40,8 +40,16 @@ namespace ffi {
 /**
  * Helper macro to construct a safe call
  *
- * \brief Marks the begining of the safe call that catches exception explicitly
+ * \brief Marks the beginning of the safe call that catches exception 
explicitly
+ * \sa TVM_FFI_SAFE_CALL_END
  *
+ * \code
+ * int TVMFFICStyleFunction() {
+ *   TVM_FFI_SAFE_CALL_BEGIN();
+ *   // c++ code region here
+ *   TVM_FFI_SAFE_CALL_END();
+ * }
+ * \endcode
  */
 #define TVM_FFI_SAFE_CALL_BEGIN() \
   try {                           \
@@ -66,6 +74,15 @@ namespace ffi {
   }                                                                            
                \
   TVM_FFI_UNREACHABLE()
 
+/*!
+ * \brief Macro to check a call to TVMFFISafeCallType and raise exception if 
error happens.
+ * \param func The function to check.
+ *
+ * \code
+ * // calls TVMFFITypeKeyToIndex and raises exception if error happens
+ * TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index));
+ * \endcode
+ */
 #define TVM_FFI_CHECK_SAFE_CALL(func)                      \
   {                                                        \
     int ret_code = (func);                                 \
@@ -79,28 +96,34 @@ namespace ffi {
 
 /*!
  * \brief Object container class that backs ffi::Function
- * \note Do not use this function directly, use ffi::Function
+ * \note Do not use this class directly, use ffi::Function
  */
 class FunctionObj : public Object, public TVMFFIFunctionCell {
  public:
+  /*! \brief Typedef for C++ style calling signature that comes with exception 
propagation */
   typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*);
   using TVMFFIFunctionCell::safe_call;
-  /*! \brief A C++ style call implementation, with exception propagation in 
c++ style. */
+  /*! \brief A C++ style call implementation, with exception propagation in 
C++ style. */
   FCall call;
-
+  /*!
+   * \brief Call the function in packed format.
+   * \param args The arguments
+   * \param num_args The number of arguments
+   * \param result The return value.
+   */
   TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* 
result) const {
     this->call(this, args, num_args, result);
   }
-
+  /// \cond Doxygen_Suppress
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction;
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIFunction;
-
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object);
+  /// \endcond
 
  protected:
   /*! \brief Make default constructor protected. */
   FunctionObj() {}
-
+  /// \cond Doxygen_Suppress
   // Implementing safe call style
   static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, 
TVMFFIAny* result) {
     TVM_FFI_SAFE_CALL_BEGIN();
@@ -110,7 +133,7 @@ class FunctionObj : public Object, public 
TVMFFIFunctionCell {
                reinterpret_cast<Any*>(result));
     TVM_FFI_SAFE_CALL_END();
   }
-
+  /// \endcond
   friend class Function;
 };
 
@@ -118,7 +141,7 @@ namespace details {
 /*!
  * \brief Derived object class for constructing FunctionObj backed by a 
TCallable
  *
- * This is a helper class that
+ * This is a helper class that implements the function call interface.
  */
 template <typename TCallable>
 class FunctionObjImpl : public FunctionObj {
@@ -386,14 +409,32 @@ class Function : public ObjectRef {
     }
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will return std::nullopt if the function is not found.
+   */
   static std::optional<Function> GetGlobal(const std::string& name) {
     return GetGlobal(std::string_view(name.data(), name.length()));
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will return std::nullopt if the function is not found.
+   */
   static std::optional<Function> GetGlobal(const String& name) {
     return GetGlobal(std::string_view(name.data(), name.length()));
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will return std::nullopt if the function is not found.
+   */
   static std::optional<Function> GetGlobal(const char* name) {
     return GetGlobal(std::string_view(name));
   }
@@ -411,14 +452,32 @@ class Function : public ObjectRef {
     return *res;
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will throw an error if the function is not found.
+   */
   static Function GetGlobalRequired(const std::string& name) {
     return GetGlobalRequired(std::string_view(name.data(), name.length()));
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will throw an error if the function is not found.
+   */
   static Function GetGlobalRequired(const String& name) {
     return GetGlobalRequired(std::string_view(name.data(), name.length()));
   }
 
+  /*!
+   * \brief Get global function by name
+   * \param name The name of the function
+   * \return The global function
+   * \note This function will throw an error if the function is not found.
+   */
   static Function GetGlobalRequired(const char* name) {
     return GetGlobalRequired(std::string_view(name));
   }
@@ -514,7 +573,8 @@ class Function : public ObjectRef {
   /*!
    * \brief Call the function in packed format.
    * \param args The arguments
-   * \param rv The return value.
+   * \param num_args The number of arguments
+   * \param result The return value.
    */
   TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* 
result) const {
     static_cast<FunctionObj*>(data_.get())->CallPacked(args, num_args, result);
@@ -533,7 +593,9 @@ class Function : public ObjectRef {
   /*! \return Whether the packed function is not nullptr */
   TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != 
nullptr; }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj);
+  /// \endcond
 
   class Registry;
 
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index 533d000427..2e4f3cd6b4 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -33,16 +33,7 @@ namespace tvm {
 namespace ffi {
 
 /*! \brief Deleter function for obeject */
-typedef void (*FObjectDeleter)(TVMFFIObject* obj, int flags);
-
-/*!
- * \brief Allocate an object using default allocator.
- * \param args arguments to the constructor.
- * \tparam T the node type.
- * \return The ObjectPtr to the allocated object.
- */
-template <typename T, typename... Args>
-inline ObjectPtr<T> make_object(Args&&... args);
+typedef void (*FObjectDeleter)(void* obj, int flags);
 
 // Detail implementations after this
 //
@@ -53,7 +44,7 @@ inline ObjectPtr<T> make_object(Args&&... args);
 // - Arena allocator that gives ownership of memory to arena (deleter = 
nullptr)
 // - Thread-local object pools: one pool per size and alignment requirement.
 // - Can specialize by type of object to give the specific allocator to each 
object.
-
+namespace details {
 /*!
  * \brief Base class of object allocators that implements make.
  *  Use curiously recurring template pattern.
@@ -138,8 +129,9 @@ class SimpleObjAllocator : public 
ObjAllocatorBase<SimpleObjAllocator> {
     static FObjectDeleter Deleter() { return Deleter_; }
 
    private:
-    static void Deleter_(TVMFFIObject* objptr, int flags) {
-      T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(objptr);
+    static void Deleter_(void* objptr, int flags) {
+      T* tptr =
+          
details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(static_cast<TVMFFIObject*>(objptr));
       if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
         // It is important to do tptr->T::~T(),
         // so that we explicitly call the specific destructor
@@ -188,8 +180,9 @@ class SimpleObjAllocator : public 
ObjAllocatorBase<SimpleObjAllocator> {
     static FObjectDeleter Deleter() { return Deleter_; }
 
    private:
-    static void Deleter_(TVMFFIObject* objptr, int flags) {
-      ArrayType* tptr = 
details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(objptr);
+    static void Deleter_(void* objptr, int flags) {
+      ArrayType* tptr = 
details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(
+          static_cast<TVMFFIObject*>(objptr));
       if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
         // It is important to do tptr->ArrayType::~ArrayType(),
         // so that we explicitly call the specific destructor
@@ -204,22 +197,35 @@ class SimpleObjAllocator : public 
ObjAllocatorBase<SimpleObjAllocator> {
     }
   };
 };
+}  // namespace details
 
+/*!
+ * \brief Allocate an object
+ * \param args arguments to the constructor.
+ * \tparam T the node type.
+ * \return The ObjectPtr to the allocated object.
+ */
 template <typename T, typename... Args>
 inline ObjectPtr<T> make_object(Args&&... args) {
-  return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
+  return 
details::SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
 }
 
+/*!
+ * \brief Allocate an Object with additional ElemType[num_elems] that are 
stored right after.
+ * \param num_elems The number of elements in the array.
+ * \param args arguments to the constructor.
+ * \tparam ArrayType the array type.
+ * \tparam ElemType the element type.
+ * \return The ObjectPtr to the allocated array.
+ */
 template <typename ArrayType, typename ElemType, typename... Args>
 inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, 
Args&&... args) {
-  return SimpleObjAllocator().make_inplace_array<ArrayType, 
ElemType>(num_elems,
-                                                                      
std::forward<Args>(args)...);
+  return details::SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(
+      num_elems, std::forward<Args>(args)...);
 }
 
 }  // namespace ffi
 
-// Export the make_object function
-// rationale: ease of use, and no ambiguity
 using ffi::make_object;
 }  // namespace tvm
 #endif  // TVM_FFI_MEMORY_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index ab0e424551..c1ab9d16d9 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -34,34 +34,63 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief TypeIndex enum, alias of TVMFFITypeIndex.
+ */
 using TypeIndex = TVMFFITypeIndex;
+
+/*!
+ * \brief TypeInfo, alias of TVMFFITypeInfo.
+ */
 using TypeInfo = TVMFFITypeInfo;
 
 /*!
  * \brief Known type keys for pre-defined types.
  */
 struct StaticTypeKey {
+  /*! \brief The type key for Any */
   static constexpr const char* kTVMFFIAny = "Any";
+  /*! \brief The type key for None */
   static constexpr const char* kTVMFFINone = "None";
+  /*! \brief The type key for bool */
   static constexpr const char* kTVMFFIBool = "bool";
+  /*! \brief The type key for int */
   static constexpr const char* kTVMFFIInt = "int";
+  /*! \brief The type key for float */
   static constexpr const char* kTVMFFIFloat = "float";
+  /*! \brief The type key for void* */
   static constexpr const char* kTVMFFIOpaquePtr = "void*";
+  /*! \brief The type key for DataType */
   static constexpr const char* kTVMFFIDataType = "DataType";
+  /*! \brief The type key for Device */
   static constexpr const char* kTVMFFIDevice = "Device";
+  /*! \brief The type key for const char* */
   static constexpr const char* kTVMFFIRawStr = "const char*";
+  /*! \brief The type key for TVMFFIByteArray* */
   static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*";
+  /*! \brief The type key for ObjectRValueRef */
   static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef";
+  /*! \brief The type key for SmallStr */
   static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
+  /*! \brief The type key for SmallBytes */
   static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
+  /*! \brief The type key for Bytes */
   static constexpr const char* kTVMFFIBytes = "ffi.Bytes";
+  /*! \brief The type key for String */
   static constexpr const char* kTVMFFIStr = "ffi.String";
+  /*! \brief The type key for Shape */
   static constexpr const char* kTVMFFIShape = "ffi.Shape";
+  /*! \brief The type key for Tensor */
   static constexpr const char* kTVMFFITensor = "ffi.Tensor";
+  /*! \brief The type key for Object */
   static constexpr const char* kTVMFFIObject = "ffi.Object";
+  /*! \brief The type key for Function */
   static constexpr const char* kTVMFFIFunction = "ffi.Function";
+  /*! \brief The type key for Array */
   static constexpr const char* kTVMFFIArray = "ffi.Array";
+  /*! \brief The type key for Map */
   static constexpr const char* kTVMFFIMap = "ffi.Map";
+  /*! \brief The type key for Module */
   static constexpr const char* kTVMFFIModule = "ffi.Module";
 };
 
@@ -95,7 +124,7 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t 
object_type_index);
 }  // namespace details
 
 /*!
- * \brief base class of all object containers.
+ * \brief Base class of all object containers.
  *
  * Sub-class of objects should declare the following static constexpr fields:
  *
@@ -189,11 +218,14 @@ class Object {
     return std::string(type_info->type_key.data, type_info->type_key.size);
   }
 
+  /*!
+   * \return Whether the object.use_count() == 1.
+   */
   bool unique() const { return use_count() == 1; }
 
   /*!
    * \return The usage count of the cell.
-   * \note We use stl style naming to be consistent with known API in 
shared_ptr.
+   * \note We use STL style naming to be consistent with known API in 
shared_ptr.
    */
   int32_t use_count() const {
     // only need relaxed load of counters
@@ -204,19 +236,26 @@ class Object {
 #endif
   }
 
-  // Information about the object
+  
//----------------------------------------------------------------------------
+  //  The following fields are configuration flags for subclasses of object
+  
//----------------------------------------------------------------------------
+  /*! \brief The type key of the class */
   static constexpr const char* _type_key = StaticTypeKey::kTVMFFIObject;
-
-  // Default object type properties for sub-classes
+  /*! \brief Whether the class is final */
   static constexpr bool _type_final = false;
+  /*! \brief Whether to allow mutable access to fields */
   static constexpr bool _type_mutable = false;
+  /*! \brief The number of child slots of the class to pre-allocate to this 
type */
   static constexpr uint32_t _type_child_slots = 0;
+  /*!
+   * \brief Whether to allow additional children beyond those pre-specified by _type_child_slots
+   */
   static constexpr bool _type_child_slots_can_overflow = true;
-  // NOTE: static type index field of the class
+  /*! \brief The static type index of the class */
   static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject;
-  // the static type depth of the class
+  /*! \brief The static depth of the class in the object hierarchy */
   static constexpr int32_t _type_depth = 0;
-  // the structural equality and hash kind of the type
+  /*! \brief The structural equality and hash kind of the type */
   static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindUnsupported;
   // The following functions are provided by macro
   // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO
@@ -761,7 +800,7 @@ class ObjectRef {
 
   /*! \brief type indicate the container type. */
   using ContainerType = Object;
-  // Default type properties for the reference class.
+  /*! \brief Whether the reference can point to nullptr */
   static constexpr bool _type_is_nullable = true;
 
  protected:
@@ -804,7 +843,7 @@ struct ObjectPtrEqual {
   TVM_FFI_INLINE bool operator()(const Variant<V...>& a, const Variant<V...>& 
b) const;
 };
 
-// If dynamic type is enabled, we still need to register the runtime type of 
parent
+/// \cond Doxygen_Suppress
 #define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType)                
               \
   static constexpr int32_t _type_depth = ParentType::_type_depth + 1;          
               \
   static int32_t _GetOrAllocRuntimeTypeIndex() {                               
               \
@@ -820,6 +859,7 @@ struct ObjectPtrEqual {
     return tindex;                                                             
               \
   }                                                                            
               \
   static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex()
+/// \endcond
 
 /*!
  * \brief Helper macro to declare a object that comes with static type index.
@@ -862,7 +902,7 @@ struct ObjectPtrEqual {
   static const constexpr bool _type_final [[maybe_unused]] = true;   \
   TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
 
-/*
+/*!
  * \brief Define object reference methods.
  *
  * \param TypeName The object type name
@@ -880,7 +920,7 @@ struct ObjectPtrEqual {
   const ObjectName* get() const { return operator->(); }                       
                \
   using ContainerType = ObjectName
 
-/*
+/*!
  * \brief Define object reference methods do not have undefined state.
  *
  * \param TypeName The object type name
@@ -895,7 +935,7 @@ struct ObjectPtrEqual {
   static constexpr bool _type_is_nullable = false;                             
                \
   using ContainerType = ObjectName
 
-/*
+/*!
  * \brief Define object reference methods of whose content is mutable.
  * \param TypeName The object type name
  * \param ParentType The parent type of the objectref
@@ -910,7 +950,7 @@ struct ObjectPtrEqual {
   ObjectName* operator->() const { return 
static_cast<ObjectName*>(data_.get()); }          \
   using ContainerType = ObjectName
 
-/*
+/*!
  * \brief Define object reference methods that is both not nullable and 
mutable.
  *
  * \param TypeName The object type name
diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h
index a52f64e483..3f406d4181 100644
--- a/ffi/include/tvm/ffi/optional.h
+++ b/ffi/include/tvm/ffi/optional.h
@@ -38,7 +38,7 @@ namespace ffi {
 
 // Note: We place optional in tvm/ffi instead of tvm/ffi/container
 // because optional itself is an inherent core component of the FFI system.
-
+/// \cond Doxygen_Suppress
 template <typename T>
 inline constexpr bool is_optional_type_v = false;
 
@@ -50,6 +50,7 @@ inline constexpr bool is_optional_type_v<Optional<T>> = true;
 template <typename T>
 inline constexpr bool use_ptr_based_optional_v =
     (std::is_base_of_v<ObjectRef, T> && !is_optional_type_v<T>);
+/// \endcond
 
 // Specialization for non-ObjectRef types.
 // simply fallback to std::optional
@@ -410,7 +411,6 @@ class Optional<T, 
std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
 };
 }  // namespace ffi
 
-// Expose to the tvm namespace
 using ffi::Optional;
 }  // namespace tvm
 #endif  // TVM_FFI_OPTIONAL_H_
diff --git a/ffi/include/tvm/ffi/reflection/access_path.h 
b/ffi/include/tvm/ffi/reflection/access_path.h
index 267cb76fc1..c614d4ca28 100644
--- a/ffi/include/tvm/ffi/reflection/access_path.h
+++ b/ffi/include/tvm/ffi/reflection/access_path.h
@@ -37,14 +37,23 @@ namespace tvm {
 namespace ffi {
 namespace reflection {
 
+/*!
+ * \brief The kind of the access pattern.
+ */
 enum class AccessKind : int32_t {
+  /*! \brief Object attribute access. */
   kAttr = 0,
+  /*! \brief Array item access. */
   kArrayItem = 1,
+  /*! \brief Map item access. */
   kMapItem = 2,
   // the following two are used for error reporting when
   // the supposed access field is not available
+  /*! \brief Object attribute missing access. */
   kAttrMissing = 3,
+  /*! \brief Array item missing access. */
   kArrayItemMissing = 4,
+  /*! \brief Map item missing access. */
   kMapItemMissing = 5,
 };
 
@@ -68,6 +77,11 @@ class AccessStepObj : public Object {
 
   // default constructor to enable auto-serialization
   AccessStepObj() = default;
+  /*!
+   * \brief Constructor
+   * \param kind The kind of the access step.
+   * \param key The key of the access step.
+   */
   AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {}
 
   /*!
@@ -77,9 +91,11 @@ class AccessStepObj : public Object {
    */
   inline bool StepEqual(const AccessStep& other) const;
 
+  /// \cond Doxygen_Suppress
   static constexpr const char* _type_key = "ffi.reflection.AccessStep";
   static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindConstTreeNode;
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object);
+  /// \endcond
 };
 
 /*!
@@ -89,27 +105,65 @@ class AccessStepObj : public Object {
  */
 class AccessStep : public ObjectRef {
  public:
+  /*!
+   * \brief Constructor
+   * \param kind The kind of the access step.
+   * \param key The key of the access step.
+   * \note The step is backed by a newly allocated AccessStepObj.
+   */
   AccessStep(AccessKind kind, Any key) : 
ObjectRef(make_object<AccessStepObj>(kind, key)) {}
 
+  /*!
+   * \brief Create an access step for an object attribute access.
+   * \param field_name The name of the field to access.
+   * \return The access step.
+   */
   static AccessStep Attr(String field_name) { return 
AccessStep(AccessKind::kAttr, field_name); }
 
+  /*!
+   * \brief Create an access step for an object attribute missing access.
+   * \param field_name The name of the field to access.
+   * \return The access step.
+   */
   static AccessStep AttrMissing(String field_name) {
     return AccessStep(AccessKind::kAttrMissing, field_name);
   }
 
+  /*!
+   * \brief Create an access step for an array item access.
+   * \param index The index of the array item to access.
+   * \return The access step.
+   */
   static AccessStep ArrayItem(int64_t index) { return 
AccessStep(AccessKind::kArrayItem, index); }
 
+  /*!
+   * \brief Create an access step for an array item missing access.
+   * \param index The index of the array item to access.
+   * \return The access step.
+   */
   static AccessStep ArrayItemMissing(int64_t index) {
     return AccessStep(AccessKind::kArrayItemMissing, index);
   }
 
+  /*!
+   * \brief Create an access step for a map item access.
+   * \param key The key of the map item to access.
+   * \return The access step.
+   */
   static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, 
key); }
 
+  /*!
+   * \brief Create an access step for a map item missing access.
+   * \param key The key of the map item to access.
+   * \return The access step.
+   */
   static AccessStep MapItemMissing(Any key = nullptr) {
     return AccessStep(AccessKind::kMapItemMissing, key);
   }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, 
AccessStepObj);
+  /// \endcond
 };
 
 inline bool AccessStepObj::StepEqual(const AccessStep& other) const {
@@ -231,9 +285,11 @@ class AccessPathObj : public Object {
    */
   inline bool IsPrefixOf(const AccessPath& other) const;
 
+  /// \cond Doxygen_Suppress
   static constexpr const char* _type_key = "ffi.reflection.AccessPath";
   static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindConstTreeNode;
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessPathObj, Object);
+  /// \endcond
 
  private:
   static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) {
@@ -301,9 +357,14 @@ class AccessPath : public ObjectRef {
     return AccessPath(make_object<AccessPathObj>(std::nullopt, std::nullopt, 
0));
   }
 
+  /// \cond Doxygen_Suppress
   TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, 
AccessPathObj);
+  /// \endcond
 };
 
+/*!
+ * \brief The pair of access paths.
+ */
 using AccessPathPair = Tuple<AccessPath, AccessPath>;
 
 inline Optional<AccessPath> AccessPathObj::GetParent() const {
diff --git a/ffi/include/tvm/ffi/reflection/accessor.h 
b/ffi/include/tvm/ffi/reflection/accessor.h
index 5215444052..5fadd0985d 100644
--- a/ffi/include/tvm/ffi/reflection/accessor.h
+++ b/ffi/include/tvm/ffi/reflection/accessor.h
@@ -57,11 +57,25 @@ inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view 
type_key, const char
  */
 class FieldGetter {
  public:
+  /*!
+   * \brief Constructor
+   * \param field_info The field info.
+   */
   explicit FieldGetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
 
+  /*!
+   * \brief Constructor
+   * \param type_key The type key.
+   * \param field_name The name of the field.
+   */
   explicit FieldGetter(std::string_view type_key, const char* field_name)
       : FieldGetter(GetFieldInfo(type_key, field_name)) {}
 
+  /*!
+   * \brief Get the value of the field
+   * \param obj_ptr The object pointer.
+   * \return The value of the field.
+   */
   Any operator()(const Object* obj_ptr) const {
     Any result;
     const void* addr = reinterpret_cast<const char*>(obj_ptr) + 
field_info_->offset;
@@ -83,11 +97,25 @@ class FieldGetter {
  */
 class FieldSetter {
  public:
+  /*!
+   * \brief Constructor
+   * \param field_info The field info.
+   */
   explicit FieldSetter(const TVMFFIFieldInfo* field_info) : 
field_info_(field_info) {}
 
+  /*!
+   * \brief Constructor
+   * \param type_key The type key.
+   * \param field_name The name of the field.
+   */
   explicit FieldSetter(std::string_view type_key, const char* field_name)
       : FieldSetter(GetFieldInfo(type_key, field_name)) {}
 
+  /*!
+   * \brief Set the value of the field
+   * \param obj_ptr The object pointer.
+   * \param value The value to be set.
+   */
   void operator()(const Object* obj_ptr, AnyView value) const {
     const void* addr = reinterpret_cast<const char*>(obj_ptr) + 
field_info_->offset;
     TVM_FFI_CHECK_SAFE_CALL(
@@ -104,8 +132,15 @@ class FieldSetter {
   const TVMFFIFieldInfo* field_info_;
 };
 
+/*!
+ * \brief Helper class to get type attribute column.
+ */
 class TypeAttrColumn {
  public:
+  /*!
+   * \brief Constructor
+   * \param attr_name The name of the type attribute.
+   */
   explicit TypeAttrColumn(std::string_view attr_name) {
     TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()};
     column_ = TVMFFIGetTypeAttrColumn(&attr_name_array);
@@ -113,7 +148,11 @@ class TypeAttrColumn {
       TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << 
attr_name;
     }
   }
-
+  /*!
+   * \brief Get the type attribute column by type index.
+   * \param type_index The type index.
+   * \return The type attribute column.
+   */
   AnyView operator[](int32_t type_index) const {
     size_t tindex = static_cast<size_t>(type_index);
     if (tindex >= column_->size) {
diff --git a/ffi/include/tvm/ffi/reflection/creator.h 
b/ffi/include/tvm/ffi/reflection/creator.h
index 983b8034a3..774eb8b0b4 100644
--- a/ffi/include/tvm/ffi/reflection/creator.h
+++ b/ffi/include/tvm/ffi/reflection/creator.h
@@ -36,9 +36,17 @@ namespace reflection {
  */
 class ObjectCreator {
  public:
+  /*!
+   * \brief Constructor
+   * \param type_key The type key.
+   */
   explicit ObjectCreator(std::string_view type_key)
       : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}
 
+  /*!
+   * \brief Constructor
+   * \param type_info The type info.
+   */
   explicit ObjectCreator(const TVMFFITypeInfo* type_info) : 
type_info_(type_info) {
     int32_t type_index = type_info->type_index;
     if (type_info->metadata == nullptr) {
diff --git a/ffi/include/tvm/ffi/reflection/registry.h 
b/ffi/include/tvm/ffi/reflection/registry.h
index 107a6e7759..ba723fa394 100644
--- a/ffi/include/tvm/ffi/reflection/registry.h
+++ b/ffi/include/tvm/ffi/reflection/registry.h
@@ -36,7 +36,10 @@ namespace ffi {
 /*! \brief Reflection namespace */
 namespace reflection {
 
-/*! \brief Trait that can be used to set field info */
+/*!
+ * \brief Trait that can be used to set field info
+ * \sa DefaultValue, AttachFieldFlag
+ */
 struct FieldInfoTrait {};
 
 /*!
@@ -44,8 +47,16 @@ struct FieldInfoTrait {};
  */
 class DefaultValue : public FieldInfoTrait {
  public:
+  /*!
+   * \brief Constructor
+   * \param value The value to be set
+   */
   explicit DefaultValue(Any value) : value_(value) {}
 
+  /*!
+   * \brief Apply the default value to the field info
+   * \param info The field info.
+   */
   TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
     info->default_value = AnyView(value_).CopyToTVMFFIAny();
     info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
@@ -55,7 +66,7 @@ class DefaultValue : public FieldInfoTrait {
   Any value_;
 };
 
-/*
+/*!
  * \brief Trait that can be used to attach field flag
  */
 class AttachFieldFlag : public FieldInfoTrait {
@@ -82,6 +93,10 @@ class AttachFieldFlag : public FieldInfoTrait {
     return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore);
   }
 
+  /*!
+   * \brief Apply the field flag to the field info
+   * \param info The field info.
+   */
   TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= 
flag_; }
 
  private:
@@ -104,6 +119,7 @@ TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T 
Class::*field_ptr) {
   return field_offset_to_class - 
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
 }
 
+/// \cond Doxygen_Suppress
 class ReflectionDefBase {
  protected:
   template <typename T>
@@ -203,10 +219,19 @@ class ReflectionDefBase {
     return ffi::Function::FromTyped(std::forward<Func>(func), name);
   }
 };
+/// \endcond
 
+/*!
+ * \brief GlobalDef helper to register a global function.
+ *
+ * \code
+ *  namespace refl = tvm::ffi::reflection;
+ *  refl::GlobalDef().def("my_ffi_extension.my_function", MyFunction);
+ * \endcode
+ */
 class GlobalDef : public ReflectionDefBase {
  public:
-  /*
+  /*!
    * \brief Define a global function.
    *
    * \tparam Func The function type.
@@ -214,7 +239,7 @@ class GlobalDef : public ReflectionDefBase {
    *
    * \param name The name of the function.
    * \param func The function to be registered.
-   * \param extra The extra arguments that can be docstring.
+   * \param extra The extra arguments that can be docstring or subclass of 
FieldInfoTrait.
    *
    * \return The reflection definition.
    */
@@ -225,7 +250,7 @@ class GlobalDef : public ReflectionDefBase {
     return *this;
   }
 
-  /*
+  /*!
    * \brief Define a global function in ffi::PackedArgs format.
    *
    * \tparam Func The function type.
@@ -233,7 +258,7 @@ class GlobalDef : public ReflectionDefBase {
    *
    * \param name The name of the function.
    * \param func The function to be registered.
-   * \param extra The extra arguments that can be docstring.
+   * \param extra The extra arguments that can be docstring or subclass of 
FieldInfoTrait.
    *
    * \return The reflection definition.
    */
@@ -243,7 +268,7 @@ class GlobalDef : public ReflectionDefBase {
     return *this;
   }
 
-  /*
+  /*!
    * \brief Expose a class method as a global function.
    *
    * An argument will be added to the first position if the function is not 
static.
@@ -253,6 +278,7 @@ class GlobalDef : public ReflectionDefBase {
    *
    * \param name The name of the method.
    * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
    *
    * \return The reflection definition.
    */
@@ -279,9 +305,23 @@ class GlobalDef : public ReflectionDefBase {
   }
 };
 
+/*!
+ * \brief Helper to register Object's reflection metadata.
+ * \tparam Class The class type.
+ *
+ * \code
+ *  namespace refl = tvm::ffi::reflection;
+ *  refl::ObjectDef<MyClass>().def_ro("my_field", &MyClass::my_field);
+ * \endcode
+ */
 template <typename Class>
 class ObjectDef : public ReflectionDefBase {
  public:
+  /*!
+   * \brief Constructor
+   * \tparam ExtraArgs The extra arguments.
+   * \param extra_args The extra arguments.
+   */
   template <typename... ExtraArgs>
   explicit ObjectDef(ExtraArgs&&... extra_args)
       : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), 
type_key_(Class::_type_key) {
@@ -430,14 +470,30 @@ class ObjectDef : public ReflectionDefBase {
   const char* type_key_;
 };
 
+/*!
+ * \brief Helper to register type attribute.
+ * \tparam Class The class type; must be derived from Object (enforced via
+ *         std::enable_if_t on the second, unnamed template parameter).
+ *
+ * \code
+ *  namespace refl = tvm::ffi::reflection;
+ *  refl::TypeAttrDef<MyClass>().def("func_attr", MyFunc);
+ * \endcode
+ *
+ */
 template <typename Class, typename = 
std::enable_if_t<std::is_base_of_v<Object, Class>>>
 class TypeAttrDef : public ReflectionDefBase {
  public:
+  /*!
+   * \brief Constructor
+   * \tparam ExtraArgs The types of the extra arguments.
+   * \param extra_args The extra arguments (currently unused by this constructor).
+   */
   template <typename... ExtraArgs>
   explicit TypeAttrDef(ExtraArgs&&... extra_args)
       : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}
 
-  /*
+  /*!
    * \brief Define a function-valued type attribute.
    *
    * \tparam Func The function type.
@@ -457,7 +513,7 @@ class TypeAttrDef : public ReflectionDefBase {
     return *this;
   }
 
-  /*
+  /*!
    * \brief Define a constant-valued type attribute.
    *
    * \tparam T The type of the value.
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index fe84b61547..8da70e5996 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -54,7 +54,7 @@ class BytesObjBase : public Object, public TVMFFIByteArray {};
 
 /*!
  * \brief An object representing bytes.
- * \note We use separate object for bytes to follow python convention
+ * \note We use a separate object for bytes to follow Python convention
  *       and indicate passing of raw bytes.
  *       Bytes can be converted from/to string.
  */
@@ -66,7 +66,7 @@ class BytesObj : public BytesObjBase {
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(BytesObj, Object);
 };
 
-/*! \brief An object representing string. It's POD type. */
+/*! \brief An object representing string. This is a POD type. */
 class StringObj : public BytesObjBase {
  public:
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr;
@@ -257,13 +257,14 @@ class Bytes {
   /*!
    * \brief constructor from size
    *
-   * \param other a char array.
+   * \param data The data pointer.
+   * \param size The size of the char array.
    */
   Bytes(const char* data, size_t size) { this->InitData(data, size); }
   /*!
    * \brief constructor from TVMFFIByteArray
    *
-   * \param other a char array.
+   * \param bytes The byte array (data pointer and size) to construct from.
    */
   Bytes(TVMFFIByteArray bytes) {  // NOLINT(*)
     this->InitData(bytes.data, bytes.size);
@@ -391,10 +392,26 @@ class String {
    */
   String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); }
   // constructors from Any
-  String(const String& other) = default;             // NOLINT(*)
-  String(String&& other) = default;                  // NOLINT(*)
+  /*!
+   * \brief Copy constructor
+   * \param other The other string
+   */
+  String(const String& other) = default;  // NOLINT(*)
+  /*!
+   * \brief Move constructor
+   * \param other The other string
+   */
+  String(String&& other) = default;  // NOLINT(*)
+  /*!
+   * \brief Copy assignment operator
+   * \param other The other string
+   */
   String& operator=(const String& other) = default;  // NOLINT(*)
-  String& operator=(String&& other) = default;       // NOLINT(*)
+  /*!
+   * \brief Move assignment operator
+   * \param other The other string
+   */
+  String& operator=(String&& other) = default;  // NOLINT(*)
 
   /*!
    * \brief Swap this String with another string
@@ -404,15 +421,27 @@ class String {
     std::swap(data_, other.data_);
   }
 
+  /*!
+   * \brief Assignment from a std::string
+   * \param other The std::string to assign from
+   */
   String& operator=(const std::string& other) {
     String(other).swap(*this);  // NOLINT(*)
     return *this;
   }
+  /*!
+   * \brief Assignment from an rvalue std::string
+   * \param other The std::string to move from
+   */
   String& operator=(std::string&& other) {
     String(std::move(other)).swap(*this);  // NOLINT(*)
     return *this;
   }
 
+  /*!
+   * \brief Assignment from a C-style string
+   * \param other The character array to assign from
+   */
   String& operator=(const char* other) {
     String(other).swap(*this);  // NOLINT(*)
     return *this;
@@ -421,9 +450,10 @@ class String {
   /*!
    * \brief constructor from raw string
    *
-   * \param other a char array.
+   * \param data The data pointer.
+   * \param size The size of the char array.
    */
-  String(const char* other, size_t size) { this->InitData(other, size); }
+  String(const char* data, size_t size) { this->InitData(data, size); }
 
   /*!
    * \brief constructor from raw string
@@ -640,6 +670,7 @@ class String {
 TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) {
   return std::string_view(str.data, str.size);
 }
+/// \cond Doxygen_Suppress
 
 template <>
 inline constexpr bool use_default_type_traits_v<Bytes> = false;
@@ -960,14 +991,14 @@ inline std::ostream& operator<<(std::ostream& out, const 
String& input) {
   out.write(input.data(), input.size());
   return out;
 }
+/// \endcond
 }  // namespace ffi
 
-// Expose to the tvm namespace for usability
-// Rationale: no ambiguity even in root
 using ffi::Bytes;
 using ffi::String;
 }  // namespace tvm
 
+/// \cond Doxygen_Suppress
 namespace std {
 
 template <>
@@ -984,4 +1015,5 @@ struct hash<::tvm::ffi::String> {
   }
 };
 }  // namespace std
+/// \endcond
 #endif  // TVM_FFI_STRING_H_
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index b972f58359..1812448ecc 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -93,8 +93,14 @@ struct TypeTraitsBase {
   }
 };
 
+/*!
+ * \brief Trait that maps a type to its field static type index
+ * \tparam T the type
+ * \note The result is exposed as the static member `value`.
+ */
 template <typename T, typename = void>
 struct TypeToFieldStaticTypeIndex {
+  /*! \brief The field static type index of the type */
   static constexpr int32_t value = TypeIndex::kTVMFFIAny;
 };
 
@@ -103,8 +109,17 @@ struct TypeToFieldStaticTypeIndex<T, 
std::enable_if_t<TypeTraits<T>::convert_ena
   static constexpr int32_t value = TypeTraits<T>::field_static_type_index;
 };
 
+/*!
+ * \brief Trait that maps a type to its runtime type index
+ * \tparam T the type
+ * \note The result is obtained by calling the static function `v()`.
+ */
 template <typename T, typename = void>
 struct TypeToRuntimeTypeIndex {
+  /*!
+   * \brief Get the runtime type index of the type
+   * \return the runtime type index
+   */
   static int32_t v() { return TypeToFieldStaticTypeIndex<T>::value; }
 };
 
@@ -161,7 +176,15 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
  */
 class StrictBool {
  public:
+  /*!
+   * \brief Constructor
+   * \param value The value of the strict bool.
+   */
   StrictBool(bool value) : value_(value) {}  // NOLINT(*)
+  /*!
+   * \brief Convert the strict bool to bool.
+   * \return The value of the strict bool.
+   */
   operator bool() const { return value_; }
 
  private:
@@ -582,6 +605,7 @@ struct TypeTraits<TObjRef, 
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef
 template <typename T, typename... FallbackTypes>
 struct FallbackOnlyTraitsBase : public TypeTraitsBase {
   // disable container for FallbackOnlyTraitsBase
+  /// \cond Doxygen_Suppress
   static constexpr bool storage_enabled = false;
 
   TVM_FFI_INLINE static std::optional<T> TryCastFromAnyView(const TVMFFIAny* 
src) {
@@ -601,6 +625,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase {
     }
     return std::nullopt;
   }
+  /// \endcond
 };
 
 /*!
@@ -616,6 +641,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase {
  */
 template <typename ObjectRefType, typename... FallbackTypes>
 struct ObjectRefWithFallbackTraitsBase : public 
ObjectRefTypeTraitsBase<ObjectRefType> {
+  /// \cond Doxygen_Suppress
   TVM_FFI_INLINE static std::optional<ObjectRefType> TryCastFromAnyView(const 
TVMFFIAny* src) {
     if (auto opt_obj = 
ObjectRefTypeTraitsBase<ObjectRefType>::TryCastFromAnyView(src)) {
       return *opt_obj;
@@ -637,6 +663,7 @@ struct ObjectRefWithFallbackTraitsBase : public 
ObjectRefTypeTraitsBase<ObjectRe
     }
     return std::nullopt;
   }
+  /// \endcond
 };
 
 // Traits for weak pointer of object
diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py 
b/python/tvm/tir/tensor_intrin/riscv_cpu.py
index febddc2bf3..e0782ada4c 100644
--- a/python/tvm/tir/tensor_intrin/riscv_cpu.py
+++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py
@@ -18,7 +18,8 @@
 """Intrinsics for RISCV tensorization"""
 
 import logging
-from tvm.ffi import register_func
+import tvm_ffi
+
 from tvm.runtime import DataType
 from tvm.script import tir as T
 from tvm.target.codegen import llvm_get_vector_width, target_has_features, 
Target
@@ -165,7 +166,7 @@ def rvv_vec_dot_product_kernels(
     return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl
 
 
-@register_func("tir.tensor_intrin.register_rvv_isa_intrinsics")
+@tvm_ffi.register_global_func("tir.tensor_intrin.register_rvv_isa_intrinsics")
 def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> 
dict():
     """Register RISCV V (vector) intrinsics
     [x] Implementation follows version 1.0 vector specifications:
diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc 
b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
index 1761f7f2dc..37ae2b4041 100644
--- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
+++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
@@ -202,7 +202,7 @@ class CUDAIPCMemoryAllocator final : public 
memory::PooledAllocator {
  * \return The allocated storage object with internal CUDA IPC memory buffer.
  */
 memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType 
dtype_hint) {
-  auto storage_obj = 
ffi::SimpleObjAllocator().make_object<memory::StorageObj>();
+  auto storage_obj = ffi::make_object<memory::StorageObj>();
   nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get();
   Device device{DLDeviceType::kDLCUDA, nccl_ctx->device_id};
   CUDAIPCMemoryAllocator* allocator = CUDAIPCMemoryAllocator::Global();

Reply via email to