This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab2b2d08e9 [FFI][DOCS] Initial docs scaffolding (#18263)
ab2b2d08e9 is described below

commit ab2b2d08e9a804dec33a384d17235653605317f6
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 1 17:34:24 2025 -0400

    [FFI][DOCS] Initial docs scaffolding (#18263)
---
 ffi/docs/.gitignore                                |   1 +
 .../get_started/run_example.sh => docs/Makefile}   |  25 +-
 ffi/docs/README.md                                 |  35 ++
 ffi/docs/concepts/abi_overview.md                  | 430 +++++++++++++++
 ffi/docs/conf.py                                   | 182 +++++++
 ffi/docs/get_started/install.md                    |  83 +++
 ffi/docs/get_started/quick_start.md                | 212 ++++++++
 ffi/docs/guides/cpp_guide.md                       | 584 +++++++++++++++++++++
 ffi/docs/guides/packaging.md                       | 282 ++++++++++
 ffi/docs/guides/python_guide.md                    | 243 +++++++++
 ffi/docs/index.rst                                 |  41 ++
 ffi/docs/requirements.txt                          |  18 +
 ffi/examples/packaging/CMakeLists.txt              |  20 +-
 ffi/examples/packaging/README.md                   |   6 +-
 ffi/examples/packaging/pyproject.toml              |   6 +-
 .../__init__.py                                    |   0
 .../_ffi_api.py                                    |   4 +-
 .../base.py                                        |   6 +-
 ffi/examples/packaging/run_example.py              |   6 +-
 ffi/examples/packaging/src/extension.cc            |   8 +-
 .../quick_start/{get_started => }/CMakeLists.txt   |   0
 .../quick_start/{get_started => }/README.md        |   0
 .../quick_start/{get_started => }/run_example.py   |   0
 .../quick_start/{get_started => }/run_example.sh   |   0
 .../{get_started => }/src/add_one_cpu.cc           |   0
 .../{get_started => }/src/add_one_cuda.cu          |   0
 .../{get_started => }/src/run_example.cc           |   0
 ffi/src/ffi/extra/testing.cc                       |  35 ++
 ffi/tests/cpp/test_example.cc                      | 289 ++++++++++
 29 files changed, 2480 insertions(+), 36 deletions(-)

diff --git a/ffi/docs/.gitignore b/ffi/docs/.gitignore
new file mode 100644
index 0000000000..e35d8850c9
--- /dev/null
+++ b/ffi/docs/.gitignore
@@ -0,0 +1 @@
+_build
diff --git a/ffi/examples/quick_start/get_started/run_example.sh b/ffi/docs/Makefile
old mode 100755
new mode 100644
similarity index 53%
copy from ffi/examples/quick_start/get_started/run_example.sh
copy to ffi/docs/Makefile
index 0602b85f37..f589272b18
--- a/ffi/examples/quick_start/get_started/run_example.sh
+++ b/ffi/docs/Makefile
@@ -14,14 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#!/bin/bash
-set -ex
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= python3 -m sphinx
+SOURCEDIR     = .
+BUILDDIR      = _build
 
-cmake -B build -S .
-cmake --build build
+# Put it first so that "make" without argument is like "make help".
+help:
+       @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
-# running python example
-python run_example.py
+.PHONY: help Makefile livehtml
 
-# running c++ example
-./build/run_example
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+       @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+livehtml:
+       @sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/ffi/docs/README.md b/ffi/docs/README.md
new file mode 100644
index 0000000000..cf96b6f6d4
--- /dev/null
+++ b/ffi/docs/README.md
@@ -0,0 +1,35 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# TVM FFI Documentation
+
+To build the docs locally:
+
+First, install the tvm-ffi package
+```bash
+pip install ..
+```
+
+Then install the requirements needed to build the docs
+
+```bash
+pip install -r requirements.txt
+```
+
+Finally, build and serve the docs
+```bash
+make livehtml
+```
diff --git a/ffi/docs/concepts/abi_overview.md b/ffi/docs/concepts/abi_overview.md
new file mode 100644
index 0000000000..6d2fd10074
--- /dev/null
+++ b/ffi/docs/concepts/abi_overview.md
@@ -0,0 +1,430 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# ABI Overview
+
+This section provides an overview of the ABI convention of TVM FFI. The ABI
+is designed around the following key principles:
+
+- **Stable C ABI:** The core ABI is defined on top of a stable C ABI.
+- **Minimal and efficient:** Keep things simple when possible and deliver close-to-metal efficiency.
+- **Focus on machine learning systems:** while also ensuring reasonable extensibility.
+
+To explain the concepts in the following sections, we write in **low-level C/C++ code** when possible,
+so the code itself illustrates the low-level semantics of how to work with the ABI convention.
+These examples can serve as references for building language bindings and compiler codegen for the ABI.
+
+```{note}
+The authoritative ABI specifications are defined in [tvm/ffi/c_api.h](https://github.com/apache/tvm/blob/main/ffi/include/tvm/ffi/c_api.h) for the core ABI,
+and [tvm/ffi/extra/c_env_api.h](https://github.com/apache/tvm/blob/main/ffi/include/tvm/ffi/extra/c_env_api.h) for extra support features
+such as stream handling. This document provides explanations of the design concepts and rationales.
+```
+
+## Simplified Example
+
+Before diving into details, it is helpful to review at a high level
+what happens when a function is called through the TVM FFI ABI.
+One main design goal here is to represent all kinds of functions with a single
+unified C signature. Please review the following
+simplified code example that illustrates the key idea:
+
+```c++
+// simplified struct for TVMFFIAny
+typedef struct TVMFFIAny {
+  int32_t type_index;
+  uint32_t zero_padding;
+  // union values
+  union {
+    int64_t v_int64;       // integers
+    double v_float64;      // floating-point numbers
+    const char* v_c_str;   // raw C-string
+  };
+} TVMFFIAny;
+
+// This is the signature of TVM FFI function ABI
+typedef int (*TVMFFISafeCallType)(
+   void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result
+);
+
+// An example function signature
+int MyFunc(const char* param0, int param1);
+
+// This is what MyFunc looks like when exposed through TVM FFI ABI
+int MyFuncTVMFFISafeCall(
+  void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result
+) {
+  assert(args[0].type_index == kTVMFFIRawStr);
+  assert(args[1].type_index == kTVMFFIInt);
+  result->type_index = kTVMFFIInt;
+  result->v_int64 = MyFunc(args[0].v_c_str, args[1].v_int64);
+  // return value indicates no error occurred
+  return 0;
+}
+
+// This is how we call the MyFuncTVMFFISafeCall
+// this can happen on the caller side in another language (e.g. python)
+int CallTVMFFISafeCall(const char* param0, int param1) {
+  // arguments on stack
+  TVMFFIAny args[2], result;
+  args[0].type_index = kTVMFFIRawStr;
+  args[0].v_c_str = param0;
+  args[1].type_index = kTVMFFIInt;
+  args[1].v_int64 = param1;
+  result.type_index = kTVMFFINone;
+  // In this case we do not need handle
+  // handle is used to hold closure pointers
+  void* handle = nullptr;
+  int num_args = 2;
+  MyFuncTVMFFISafeCall(handle, args, num_args, &result);
+  return result.v_int64;
+}
+```
+
+At a high level, the `TVMFFISafeCallType` signature works as follows:
+- Arguments and return values are stored in the structured `TVMFFIAny`.
+  - Each value comes with a `type_index` to indicate its type.
+  - Values are stored in union fields, depending on the specific type.
+- The caller explicitly stores the type index and value into
+  an on-stack array of `TVMFFIAny`.
+- The callee loads the parameters from `args` and checks their type indices.
+
+In this way, the same `TVMFFISafeCallType` can represent any function
+with an arbitrary number of arguments whose types can be identified by `type_index`.
+Of course, this is a simplified example and we have not yet touched on specific details
+such as the Any value format and error handling. The following sections provide a more systematic
+treatment of each of these topics.
+You can keep this example in mind as the overall picture and refine it as you read through
+the following sections.
+
+
+## TVMFFIAny Storage Format
+
+To start with, we need a mechanism to store the values that are passed across machine learning frameworks.
+TVM FFI achieves this with a core data structure called `TVMFFIAny`.
+
+```c++
+typedef struct TVMFFIAny {
+  int32_t type_index;
+  union {  // 4 bytes
+    uint32_t zero_padding;
+    uint32_t small_str_len;
+  };
+  // union values
+  union {
+    int64_t v_int64;       // integers
+    double v_float64;      // floating-point numbers
+    void* v_ptr;           // typeless pointers
+    const char* v_c_str;   // raw C-string
+    TVMFFIObject* v_obj;   // ref counted objects
+    DLDataType v_dtype;    // data type
+    DLDevice v_device;     // device
+    char v_bytes[8];       // small string
+    ...
+  };
+} TVMFFIAny;
+```
+
+`TVMFFIAny` is a 16-byte C structure that follows the tagged-union design principle:
+
+- `type_index` identifies the type being stored.
+- The value union stores the value itself:
+  - Small POD values (like integers and floats) are stored directly as "on-stack" values.
+  - `v_obj` can also point to a managed heap-allocated object, which we will discuss next.
+- The second field stores metadata for small strings.
+
+
+### Storing a POD Value
+
+Many values are plain-old-data (POD) types. In such cases, we store them directly
+on-stack in the value part of the `TVMFFIAny`. The following example shows how to store
+an int.
+
+```c++
+void SetIntValue(TVMFFIAny* any, int value) {
+  // zero the padding field; v_int64 fills the entire 8-byte value union
+  any->type_index = kTVMFFIInt;
+  any->zero_padding = 0;
+  any->v_int64 = value;
+}
+```
+
+:::{note}
+
+We **must zero the content that is not being used** by
+the current value type. The following example shows a common place
+where mistakes can be made when we forget to zero the value field
+on 32-bit platforms (where pointers only fill the 32-bit part of the value).
+
+```c++
+void SetOpaquePtrValue(TVMFFIAny* any, void* opaque_ptr) {
+  any->type_index = kTVMFFIOpaquePtr;
+  // must zero the padding
+  any->zero_padding = 0;
+  // the zeroing is needed for 32-bit platforms!
+  any->v_uint64 = 0;
+  any->v_ptr = opaque_ptr;
+}
+```
+
+**Rationale:** Such invariants allow us to directly compare
+and hash TVMFFIAny in bytes for quick equality checks without going through
+type index switching.
+:::
+
+
+## Object Storage Format
+
+When TVMFFIAny points to a heap-allocated object (such as n-dimensional arrays),
+we adopt a unified object storage format, defined as follows:
+
+```c++
+typedef struct TVMFFIObject {
+  int32_t type_index;
+  uint32_t weak_ref_count;
+  uint64_t strong_ref_count;
+  union {
+    void (*deleter)(struct TVMFFIObject* self, int flags);
+    int64_t __ensure_align;
+  };
+} TVMFFIObject;
+```
+
+`TVMFFIObject` defines a common 24-byte intrusive header that all in-memory objects share:
+
+- `type_index` identifies the type being stored and is consistent with `TVMFFIAny.type_index`.
+- `weak_ref_count` stores the weak atomic reference counter of the object.
+- `strong_ref_count` stores the strong atomic reference counter of the object.
+- `deleter` should be called when either the strong or weak ref counter reaches zero.
+  - The flags indicate whether the weak counter, the strong counter, or both reached zero.
+  - When `strong_ref_count` reaches zero, the deleter needs to call the destructor of the object.
+  - When `weak_ref_count` reaches zero, the deleter needs to free the memory allocated for the object itself.
+
+**Rationales:** There are several considerations behind the design of this data structure:
+- `type_index` enables runtime dynamic type checking and casting.
+- We introduce weak/strong ref counters so we can be compatible with systems that need weak pointers.
+- The weak ref counter is kept as 32-bit so we can pack the object header into 24 bytes.
+- `deleter` ensures that objects allocated from one language/runtime can be safely deleted in another.
+
+The object format provides a unified way to manage object life-cycles and dynamic type casting
+for heap-allocated objects, including Shape, NDArray,
+Function, Array, Map and other custom objects. A sketch of the reference-release protocol is shown below.
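+
+To make the deleter protocol concrete, the following is a minimal sketch of how a runtime
+might release references. It assumes atomic counters and uses hypothetical flag constants
+(`kHypotheticalStrongZero`, `kHypotheticalWeakZero`) that stand in for the real values defined
+in `c_api.h`; treat it as illustrative rather than authoritative.
+
+```c++
+#include <atomic>
+#include <cstdint>
+
+// Hypothetical flag values; the authoritative constants live in tvm/ffi/c_api.h.
+constexpr int kHypotheticalStrongZero = 1;
+constexpr int kHypotheticalWeakZero = 2;
+
+// Sketch: drop one strong reference (atomicity via C++20 std::atomic_ref).
+void SketchDecStrongRef(TVMFFIObject* obj) {
+  std::atomic_ref<uint64_t> strong(obj->strong_ref_count);
+  if (strong.fetch_sub(1) == 1) {
+    // strong counter reached zero: the deleter must run the payload destructor
+    obj->deleter(obj, kHypotheticalStrongZero);
+  }
+}
+
+// Sketch: drop one weak reference.
+void SketchDecWeakRef(TVMFFIObject* obj) {
+  std::atomic_ref<uint32_t> weak(obj->weak_ref_count);
+  if (weak.fetch_sub(1) == 1) {
+    // weak counter reached zero: the deleter must free the memory of the object
+    obj->deleter(obj, kHypotheticalWeakZero);
+  }
+}
+```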
+
+
+### DLPack Compatible NDArray
+
+We provide first-class support for raw unmanaged DLPack pointers as well as a managed NDArray object that
+directly adopts the DLPack DLTensor layout. The overall layout of the NDArray object is as follows:
+
+```c++
+struct NDArrayObj: public ffi::Object, public DLTensor {
+};
+```
+
+That means we can read out the array buffer information from a `TVMFFIAny`
+in the following way:
+
+```c++
+DLTensor* ReadDLTensorPtr(const TVMFFIAny *value) {
+  if (value->type_index == kTVMFFIDLTensorPtr) {
+    return static_cast<DLTensor*>(value->v_ptr);
+  }
+  assert(value->type_index == kTVMFFINDArray);
+  return reinterpret_cast<DLTensor*>(
+    reinterpret_cast<char*>(value->v_obj) + sizeof(TVMFFIObject));
+}
+```
+The above code can be used as a reference for implementing compiler codegen that accesses tensor data.
+Note that the C++ API handles such conversions automatically.
+
+### Advanced: Dynamic Type Index
+
+The `TVMFFITypeIndex` enum defines a set of type indices. Each built-in type has a corresponding statically
+assigned type index defined in the enum. Static type indices should be sufficient for most
+library use cases.
+For advanced use cases, we also support user-defined objects whose `type_index` is assigned at startup time
+by calling `TVMFFITypeGetOrAllocIndex` with a unique
+`type_key` string. This design enables decentralized extension of the object system, as long as the `type_key`
+values are kept unique by prefixing them with a namespace.
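+
+In the C++ API, this registration is wrapped by the object macros shown in the C++ guide.
+The sketch below uses an illustrative class and `type_key` (not part of the library) to show
+how a custom object receives its dynamic type index at startup:
+
+```c++
+#include <tvm/ffi/object.h>
+
+// Illustrative custom object; the type_key must be globally unique,
+// typically namespaced with a project-specific prefix.
+class MyCustomObj : public tvm::ffi::Object {
+ public:
+  static constexpr const char* _type_key = "my_project.MyCustomObj";
+  // Registers the type with the FFI system and assigns its dynamic type index at startup.
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MyCustomObj, tvm::ffi::Object);
+};
+```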
+
+## AnyView and Managed Any
+
+A `TVMFFIAny` can either be treated as a strongly managed value (corresponding to `ffi::Any` in C++),
+or as an unmanaged value (corresponding to `ffi::AnyView` in C++).
+- For POD types, there is no difference between the two.
+- For object types, copying an AnyView should not change reference counters, while copying and deleting
+  a managed Any should increase and decrease the strong reference counter.
+- When we convert an AnyView to an Any, we convert raw C strings `const char*` and `const TVMFFIByteArray*`
+  into their managed counterparts (String and Bytes).
+- The C API function `TVMFFIAnyViewToOwnedAny` is provided to perform such conversion (see the sketch below).
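+
+The following is an illustrative C-level sketch of this promotion. The exact signature of
+`TVMFFIAnyViewToOwnedAny` is defined in `c_api.h`; the call below assumes a plausible
+`(input view, output any)` form and should be read as an assumption rather than the
+authoritative API.
+
+```c++
+// Illustrative only: promote an unmanaged raw-string view into a managed Any.
+// Consult tvm/ffi/c_api.h for the authoritative declaration of TVMFFIAnyViewToOwnedAny.
+int PromoteRawStr(const char* raw, TVMFFIAny* out) {
+  TVMFFIAny view;
+  view.type_index = kTVMFFIRawStr;
+  view.zero_padding = 0;
+  view.v_c_str = raw;
+  // after conversion, *out holds a managed String object instead of a raw pointer
+  return TVMFFIAnyViewToOwnedAny(&view, out);
+}
+```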
+
+Unless you are writing a compiler backend that needs low-level C-style access, we encourage use of the
+C++ API, which automatically manages conversion and casting between normal types and Any. The following code
+shows some example usage of the C++ API.
+
+```c++
+#include <tvm/ffi/any.h>
+
+void AnyExample() {
+  namespace ffi = tvm::ffi;
+  // Here is a managed any
+  ffi::Any value = "hello world";
+  // explicit cast to a specific type
+  ffi::String str_value = value.cast<ffi::String>();
+  // copy int to value
+  value = 1;
+  // copy into a view
+  ffi::AnyView view = value;
+  // cast view back to int
+  std::cout << "Value is " << view.cast<int>() << std::endl;
+}
+```
+
+`ffi::Any` can serve as a container type to hold managed values that can be recognized by the TVM FFI system.
+It can be composed with container structures such as `Map<String, Any>` and `Array<Any>` to represent the
+broad API patterns that appear in ML systems.
+
+## Function Calling Convention
+
+As discussed in the overview, we treat foreign function calls as first-class citizens. We adopt a single standard C function signature as follows:
+
+```c++
+typedef int (*TVMFFISafeCallType)(
+   void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result
+);
+```
+
+The `handle` contains the pointer to the function object itself, allowing us to support closures. `args` and `num_args` describe the input arguments, and `result` stores the return value. When `args` and `result` contain heap-managed objects, the caller is expected to own them.
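+
+The sketch below illustrates how `handle` can carry closure state at the C level. The
+`AddNState` struct and function names are hypothetical and only demonstrate the convention.
+
+```c++
+// Hypothetical closure state captured by the function object.
+typedef struct { int64_t increment; } AddNState;
+
+// A packed call that adds a captured increment to its single integer argument.
+int AddNSafeCall(void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) {
+  AddNState* state = (AddNState*)handle;  // recover the captured state from handle
+  assert(num_args == 1 && args[0].type_index == kTVMFFIInt);
+  result->type_index = kTVMFFIInt;
+  result->v_int64 = args[0].v_int64 + state->increment;
+  return 0;  // zero indicates success
+}
+```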
+
+```{note}
+Before calling the function, the caller must set `result->type_index` to `kTVMFFINone`, or to any type index that does not correspond
+to an on-heap object.
+
+**Rationale:** This simplifies the callee implementation, as the initial state of `result` can be treated as a managed Any.
+```
+
+We call this approach a packed function, as it provides a single signature to represent all functions in a "type-erased" way. It removes the need to declare and JIT a shim for each FFI function call while maintaining reasonable efficiency. This mechanism enables the following scenarios:
+- Calling from dynamic languages (e.g., Python): we provide a tvm_ffi binding that prepares the args by dynamically examining the Python arguments passed in.
+- Calling from static languages (e.g., C++): for static languages, we can leverage C++ templates to directly instantiate the arguments on the stack, removing the need for dynamic examination.
+- Dynamic language callbacks: the signature enables us to easily bring in dynamic language (Python) callbacks as ffi::Function, as we can take each argument and convert it to the corresponding dynamic value.
+- Efficiency: in practice, we find this approach is sufficient for machine learning focused workloads. For example, we can reach microsecond-level overhead for Python/C++ calls, which is generally similar to the overhead of eager mode. When both sides of the call are static languages, the overhead goes down to tens of nanoseconds. As a side note, although we did not find it necessary, the signature still leaves room for link time optimization (LTO), when both sides are static languages wit [...]
+
+We support first-class Function objects that allow functions/closures to be passed around between different places, enabling usages such as quick Python callbacks for prototyping and dynamic functor creation for driver-based kernel launching.
+
+
+## Error Handling
+
+Most TVM FFI C API calls, including `TVMFFISafeCallType`, use the return value to
+indicate whether an error occurred. When an error occurs during a function call,
+a non-zero value is returned. The callee also needs to set the error through the `TVMFFIErrorSetRaisedFromCStr` or `TVMFFIErrorSetRaised` API, which stores
+the error in thread-local storage.
+
+```c++
+// Example function that raises an error
+int ErrorFunc(void* handle, const TVMFFIAny* args, int num_args, TVMFFIAny *result) {
+  const char* error_kind = "RuntimeError";
+  const char* error_msg = "error message";
+  // set the thread-local error state
+  TVMFFIErrorSetRaisedFromCStr(error_kind, error_msg);
+  return -1;
+}
+```
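+
+As the caller-side counterpart of the example above, the sketch below shows the general
+pattern for checking the return code. The retrieval step is indicated only in comments,
+since the exact signature of `TVMFFIErrorMoveFromRaised` is defined in `c_api.h`.
+
+```c++
+// Sketch of the caller-side error check.
+int CallAndCheck(TVMFFISafeCallType fn, void* handle,
+                 const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) {
+  if (fn(handle, args, num_args, result) != 0) {
+    // A non-zero return means the callee set the thread-local error via
+    // TVMFFIErrorSetRaised / TVMFFIErrorSetRaisedFromCStr. The caller then moves
+    // the error out with TVMFFIErrorMoveFromRaised (see c_api.h) and rethrows
+    // or reports it in the host language.
+    return -1;
+  }
+  return 0;
+}
+```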
+
+The caller can retrieve the error from the thread-local error storage
+using the `TVMFFIErrorMoveFromRaised` function.
+The ABI also stores errors as a specific kind of Object;
+the overall error object layout is as follows:
+```c++
+typedef struct {
+  /*! \brief The kind of the error. */
+  TVMFFIByteArray kind;
+  /*! \brief The message of the error. */
+  TVMFFIByteArray message;
+  /*! \brief The traceback of the error. */
+  TVMFFIByteArray traceback;
+  /*!
+   * \brief Function handle to update the traceback of the error.
+   * \param self The self object handle.
+   * \param traceback The traceback to update.
+   */
+  void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* 
traceback);
+} TVMFFIErrorCell;
+
+// error object
+class ErrorObj : public ffi::Object, public TVMFFIErrorCell {
+};
+```
+
+The error object stores kind, message and traceback as strings. When possible,
+we store the traceback in the same format as a Python traceback (see the following example):
+```
+File "src/extension.cc", line 45, in void my_ffi_extension::RaiseError(tvm::ffi::String)
+```
+
+We provide the C++ class `ffi::Error` that can be thrown as an exception in a C++ environment. When we reach
+the C ABI boundary, we catch the error and call `TVMFFIErrorSetRaised` to propagate the error
+to the caller safely.
+`TVMFFIErrorSetRaisedFromCStr` is a convenient method to set the error directly from a C string, and can be useful in compiler backend construction to implement features such as assertions.
+
+**Rationales:** The error object contains minimal but sufficient information to reconstruct a structured
+error on the Python side. We opt for a thread-local error state as it simplifies overall support.
+
+## String and Bytes
+
+The ABI supports strings and bytes as first-class citizens. A string can take multiple forms that are identified by
+its `type_index`.
+
+- `kTVMFFIRawStr`: raw C string terminated by `\0`.
+- `kTVMFFISmallStr`: small string; the length is stored in `small_str_len` and the data is stored in `v_bytes`.
+- `kTVMFFIStr`: on-heap string object for strings that are longer than 7 characters.
+
+The following code shows the layout of the on-heap string object.
+```c++
+// span-like data structure to store header and length
+typedef struct {
+  const char* data;
+  size_t size;
+} TVMFFIByteArray;
+
+// showcase the layout of the on-heap string.
+class StringObj : public ffi::Object, public TVMFFIByteArray {
+};
+```
+
+The following code shows how to read a string from `TVMFFIAny`
+```c++
+TVMFFIByteArray ReadString(const TVMFFIAny *value) {
+  TVMFFIByteArray ret;
+  if (value->type_index == kTVMFFIRawStr) {
+    ret.data = value->v_c_str;
+    ret.size = strlen(ret.data);
+  } else if (value->type_index == kTVMFFISmallStr) {
+    ret.data = value->v_bytes;
+    ret.size = value->small_str_len;
+  } else {
+    assert(value->type_index == kTVMFFIStr);
+    ret = *reinterpret_cast<TVMFFIByteArray*>(
+      reinterpret_cast<char*>(value->v_obj) + sizeof(TVMFFIObject));
+  }
+  return ret;
+}
+```
+
+Similarly, we have type indices to represent bytes. The C++ API provides the classes
+`ffi::String` and `ffi::Bytes` to enable automatic conversion of these values to and from the Any storage format.
+
+**Rationales:** Separate string and bytes types enable clear mappings to the Python side. Small strings allow us to
+store short names on-stack. To favor 8-byte alignment (v_bytes) and keep things simple, we did not further
+pack characters into the `small_str_len` field.
diff --git a/ffi/docs/conf.py b/ffi/docs/conf.py
new file mode 100644
index 0000000000..64239487c0
--- /dev/null
+++ b/ffi/docs/conf.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# -*- coding: utf-8 -*-
+import os
+import sys
+
+import tomli
+
+# -- General configuration ------------------------------------------------
+
+# Load version from pyproject.toml
+with open("../pyproject.toml", "rb") as f:
+    pyproject_data = tomli.load(f)
+__version__ = pyproject_data["project"]["version"]
+
+project = "tvm-ffi"
+
+version = __version__
+release = __version__
+
+# -- Extensions and extension configurations --------------------------------
+
+extensions = [
+    "myst_parser",
+    "nbsphinx",
+    "autodocsumm",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosectionlabel",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.viewcode",
+    "sphinx_copybutton",
+    "sphinx_reredirects",
+    "sphinx_tabs.tabs",
+    "sphinx_toolbox.collapse",
+    "sphinxcontrib.httpdomain",
+    "sphinxcontrib.mermaid",
+]
+
+nbsphinx_allow_errors = True
+nbsphinx_execute = "never"
+
+autosectionlabel_prefix_document = True
+nbsphinx_allow_directives = True
+
+myst_enable_extensions = [
+    "dollarmath",
+    "amsmath",
+    "deflist",
+    "colon_fence",
+    "html_image",
+    "linkify",
+    "substitution",
+]
+
+myst_heading_anchors = 3
+myst_ref_domains = ["std", "py"]
+myst_all_links_external = False
+
+intersphinx_mapping = {
+    "python": ("https://docs.python.org/3.12";, None),
+    "typing_extensions": 
("https://typing-extensions.readthedocs.io/en/latest";, None),
+    "pillow": ("https://pillow.readthedocs.io/en/stable";, None),
+    "numpy": ("https://numpy.org/doc/stable";, None),
+    "torch": ("https://pytorch.org/docs/stable";, None),
+}
+
+autodoc_mock_imports = ["torch"]
+autodoc_default_options = {
+    "members": True,
+    "undoc-members": True,
+    "show-inheritance": True,
+    "inherited-members": False,
+    "member-order": "bysource",
+}
+
+# -- Other Options --------------------------------------------------------
+
+templates_path = []
+
+redirects = {}
+
+source_suffix = {".rst": "restructuredtext", ".md": "markdown"}
+
+language = "en"
+
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+# A list of ignored prefixes for module index sorting.
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = False
+
+# -- Options for HTML output ----------------------------------------------
+
+html_theme = "sphinx_book_theme"
+html_title = project
+html_copy_source = True
+html_last_updated_fmt = ""
+
+footer_dropdown = {
+    "name": "ASF",
+    "items": [
+        ("ASF Homepage", "https://apache.org/";),
+        ("License", "https://www.apache.org/licenses/";),
+        ("Sponsorship", "https://www.apache.org/foundation/sponsorship.html";),
+        ("Security", "https://tvm.apache.org/docs/reference/security.html";),
+        ("Thanks", "https://www.apache.org/foundation/thanks.html";),
+        ("Events", "https://www.apache.org/events/current-event";),
+    ],
+}
+
+
+footer_copyright = "Copyright © 2025, Apache Software Foundation"
+footer_note = (
+    "Apache TVM, Apache, the Apache feather, and the Apache TVM project "
+    + "logo are either trademarks or registered trademarks of the Apache 
Software Foundation."
+)
+
+
+def footer_html():
+    # Create footer HTML with two-line layout
+    # Generate dropdown menu items
+    dropdown_items = ""
+    for item_name, item_url in footer_dropdown["items"]:
+        dropdown_items += f'<li><a class="dropdown-item" href="{item_url}" target="_blank" style="font-size: 0.9em;">{item_name}</a></li>\n'
+
+    footer_dropdown_html = f"""
+  <div class="footer-container" style="margin: 5px 0; font-size: 0.9em; color: 
#6c757d;">
+      <div class="footer-line1" style="display: flex; justify-content: 
space-between; align-items: center; margin-bottom: 3px;">
+          <div class="footer-copyright-short">
+              {footer_copyright}
+          </div>
+          <div class="footer-dropdown">
+              <div class="dropdown">
+                  <button class="btn btn-link dropdown-toggle" type="button" 
id="footerDropdown" data-bs-toggle="dropdown"
+                  aria-expanded="false" style="font-size: 0.9em; color: 
#6c757d; text-decoration: none; padding: 0; border: none; background: none;">
+                      {footer_dropdown['name']}
+                  </button>
+                  <ul class="dropdown-menu" aria-labelledby="footerDropdown" 
style="font-size: 0.9em;">
+{dropdown_items}                  </ul>
+              </div>
+          </div>
+      </div>
+      <div class="footer-line2" style="font-size: 0.9em; color: #6c757d;">
+          {footer_note}
+      </div>
+  </div>
+  """
+    return footer_dropdown_html
+
+
+html_theme_options = {
+    "repository_url": "https://github.com/apache/tvm";,
+    "use_repository_button": True,
+    "extra_footer": footer_html(),
+}
+
+html_context = {
+    "display_github": True,
+    "github_user": "apache",
+    "github_version": "main",
+    "conf_py_path": "/ffi/docs/",
+}
diff --git a/ffi/docs/get_started/install.md b/ffi/docs/get_started/install.md
new file mode 100644
index 0000000000..87223d0114
--- /dev/null
+++ b/ffi/docs/get_started/install.md
@@ -0,0 +1,83 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# Installation
+
+TVM FFI is built and tested on Windows, macOS, and various
+Linux distributions. You can install tvm-ffi using one of the
+methods below.
+
+## Quick Start
+
+The easiest way to try it out is to install from PyPI.
+
+```bash
+pip install apache-tvm-ffi
+```
+
+After installation, you can run the following command to confirm that
+the installation was successful
+
+```bash
+tvm-ffi-config -h
+```
+
+This configuration tool also helps you query the paths and flags needed to build
+libraries against tvm-ffi.
+
+
+## Install From Source
+
+You can also build and install tvm-ffi from source.
+
+### Dependencies
+
+- CMake (>= 3.24.0)
+- Git
+- A recent C++ compiler supporting C++17, at minimum:
+    - GCC 7.1
+    - Clang 5.0
+    - Apple Clang 9.3
+    - Visual Studio 2019 (v16.7)
+- Python (>= 3.9)
+
+
+Developers can clone the source repository from GitHub.
+
+```bash
+git clone --recursive https://github.com/apache/tvm tvm
+```
+
+```{note}
+It's important to use the ``--recursive`` flag when cloning the repository, which will
+automatically clone the submodules. If you forget to use this flag, you can manually clone the submodules
+by running ``git submodule update --init --recursive`` in the root directory.
+```
+
+Then you can install directly in development mode
+
+```bash
+cd tvm/ffi
+pip install -ve .
+```
+
+The additional `-e` flag installs the Python files in editable mode,
+so edits to the Python files are immediately reflected in the installed package,
+which is useful for development.
+
+## What to Do Next
+
+Now that you have installed TVM FFI, we recommend reading the [Quick Start](./quick_start.md) tutorial.
diff --git a/ffi/docs/get_started/quick_start.md b/ffi/docs/get_started/quick_start.md
new file mode 100644
index 0000000000..1f6b25ef6d
--- /dev/null
+++ b/ffi/docs/get_started/quick_start.md
@@ -0,0 +1,212 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# Quick Start
+
+This is a quick start guide explaining the basic features and usage of tvm-ffi.
+The source code can be found at `examples/quick_start` in the project source.
+
+## Build and Run the Example
+
+Let us first get started by building and running the example, which will show us:
+
+- How to expose C++ functions as TVM FFI ABI functions
+- How to load and run a tvm-ffi based library from Python
+- How to load and run a tvm-ffi based library from C++
+
+
+Before starting, ensure you have:
+
+- TVM FFI installed following [installation](./install.md)
+- C++ compiler with C++17 support
+- CMake 3.18 or later
+- (Optional) CUDA toolkit for GPU examples
+- (Optional) PyTorch for checking torch integrations
+
+Then obtain a copy of the tvm-ffi source code.
+
+```bash
+git clone https://github.com/apache/tvm --recursive
+cd tvm/ffi
+```
+
+The example is in the `examples/quick_start` folder; you can quickly build
+it using the following commands.
+```bash
+cd examples/quick_start
+cmake -B build -S .
+cmake --build build
+```
+
+After the build finishes, you can run the Python example with
+```
+python run_example.py
+```
+
+You can also run the C++ example
+
+```
+./build/run_example
+```
+
+## Walk through the Example
+
+Now that we have quickly tried things out, let us walk through the details of the example.
+Specifically, in this example, we create a simple "add one" operation that adds 1 to each element of an input
+tensor and expose that function as a TVM FFI compatible function. The key file structure is as follows:
+
+```
+examples/quick_start/
+├── src/
+│   ├── add_one_cpu.cc      # CPU implementation
+│   ├── add_one_cuda.cu     # CUDA implementation
+│   └── run_example.cc      # C++ usage example
+├── run_example.py          # Python usage example
+├── run_example.sh          # Build and run script
+└── CMakeLists.txt          # Build configuration
+```
+
+### CPU Implementation
+
+```cpp
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/function.h>
+
+namespace tvm_ffi_example {
+
+void AddOne(DLTensor* x, DLTensor* y) {
+  // Validate inputs
+  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+  DLDataType f32_dtype{kDLFloat, 32, 1};
+  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+
+  // Perform the computation
+  for (int i = 0; i < x->shape[0]; ++i) {
+    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+  }
+}
+
+// Expose the function through TVM FFI
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example::AddOne);
+}
+```
+
+**Key Points:**
+- Functions take `DLTensor*` parameters for cross-language compatibility
+- The `TVM_FFI_DLL_EXPORT_TYPED_FUNC` macro exposes the function with a given name
+
+### CUDA Implementation
+
+```cpp
+void AddOneCUDA(DLTensor* x, DLTensor* y) {
+  // Validation (same as CPU version)
+  // ...
+
+  int64_t n = x->shape[0];
+  int64_t nthread_per_block = 256;
+  int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+
+  // Get current CUDA stream from environment
+  cudaStream_t stream = static_cast<cudaStream_t>(
+      TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+
+  // Launch kernel
+  AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
+      static_cast<float*>(x->data), static_cast<float*>(y->data), n);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
+```
+
+**Key Points:**
+- We use `TVMFFIEnvGetCurrentStream` to obtain the current stream from the environment
+- When an FFI function is invoked from Python with a PyTorch tensor as an argument,
+  the stream is populated with torch's current stream.
+
+
+### Working with PyTorch
+
+After the build, we obtain a library such as `build/add_one_cuda.so`, which can be loaded
+with the `tvm_ffi.load_module` API. The function then becomes available as an attribute of the loaded module.
+The tensor arguments of the FFI functions automatically accept torch.Tensor. The following code shows how
+to use the function with torch.
+
+```python
+import torch
+import tvm_ffi
+
+if torch.cuda.is_available():
+    mod = tvm_ffi.load_module("build/add_one_cuda.so")
+
+    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
+    y = torch.empty_like(x)
+
+    # TVM FFI automatically handles CUDA streams
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        mod.add_one_cuda(x, y)
+    stream.synchronize()
+```
+
+### Working with Python Data Arrays
+
+TVM FFI functions work automatically with Python data arrays that are compatible with DLPack.
+The following example shows how to use the function with numpy.
+
+```python
+import tvm_ffi
+import numpy as np
+
+# Load the compiled module
+mod = tvm_ffi.load_module("build/add_one_cpu.so")
+
+# Create input and output arrays
+x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
+y = np.empty_like(x)
+
+# Call the function
+mod.add_one_cpu(x, y)
+print("Result:", y)  # [2, 3, 4, 5, 6]
+```
+
+### Working with C++
+
+One important design goal of tvm-ffi is to be universally portable.
+As a result, the resulting libraries do not have an explicit dependency on Python
+and can be loaded in other language environments, such as C++. The following code
+shows how to call the exported example function from C++.
+
+```cpp
+#include <tvm/ffi/container/ndarray.h>
+#include <tvm/ffi/extra/module.h>
+
+void CallAddOne(DLTensor* x, DLTensor *y) {
+  namespace ffi = tvm::ffi;
+  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
+  ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
+  add_one_cpu(x, y);
+}
+```
+
+## Summary of Key Concepts
+
+- **TVM_FFI_DLL_EXPORT_TYPED_FUNC** exposes a C++ function through the tvm-ffi C ABI
+- **DLTensor** is a universal tensor structure that enables zero-copy exchange of array data
+- **Module loading** is provided by tvm-ffi APIs in multiple languages.
diff --git a/ffi/docs/guides/cpp_guide.md b/ffi/docs/guides/cpp_guide.md
new file mode 100644
index 0000000000..84b6fd8dc9
--- /dev/null
+++ b/ffi/docs/guides/cpp_guide.md
@@ -0,0 +1,584 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# C++ Guide
+
+This guide introduces the tvm-ffi C++ API.
+We provide a C++ API on top of the stable C ABI as a type-safe and efficient way to work with tvm-ffi.
+The C++ API is designed to abstract away the complexity of the C ABI while maintaining full compatibility.
+The C++ API builds around the following key concepts:
+
+- **Any and AnyView**: Type-erased containers that can hold values of any supported type in tvm-ffi.
+- **Function**: A type-erased "packed" function that can be invoked like a normal function.
+- **Objects and ObjectRefs**: Reference-counted objects to manage on-heap data types.
+
+Code examples in this guide use `EXPECT_EQ`, a testing framework macro, for demonstration purposes. In actual applications, you would use standard C++ assertions or error handling.
+You can find runnable versions of the examples under `tests/cpp/test_example.cc`.
+
+## Any and AnyView
+
+`Any` and `AnyView` are the foundation of tvm-ffi, providing
+ways to store values that are compatible with the ffi system.
+The following example shows how we can interact with Any and AnyView.
+
+```cpp
+
+#include <tvm/ffi/any.h>
+
+void ExampleAny() {
+  namespace ffi = tvm::ffi;
+  // Create an Any from various types
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  ffi::Any int_value = 42;
+  ffi::Any float_value = 3.14;
+  ffi::Any string_value = "hello world";
+
+  // AnyView provides a lightweight view without ownership
+  ffi::AnyView view = int_value;
+  // we can cast Any/AnyView to a specific type
+  int extracted = view.cast<int>();
+  EXPECT_EQ(extracted, 42);
+
+  // If we are not sure about the type
+  // we can use as to get an optional value
+  std::optional<int> maybe_int = view.as<int>();
+  if (maybe_int.has_value()) {
+    EXPECT_EQ(maybe_int.value(), 42);
+  }
+  // try_cast is another version that will attempt a type conversion even if the type does not exactly match
+  std::optional<int> maybe_int_try = view.try_cast<int>();
+  if (maybe_int_try.has_value()) {
+    EXPECT_EQ(maybe_int_try.value(), 42);
+  }
+}
+```
+
+At a high level, we can perform the following operations:
+
+- We can store a value into Any; under the hood, Any records the type of the value via its type_index.
+- We can fetch a value from Any or AnyView using the `cast` function.
+- If we are unsure about the type in Any, we can use the `as` or `try_cast` functions to get an optional value.
+
+Under the hood, Any and AnyView store the value via the ABI convention and also manage the reference
+counting correctly when the stored value is an on-heap object.
+
+## Object and ObjectRef
+
+The tvm-ffi object system provides the foundation for all managed, reference-counted objects
+in the system. It enables type safety, cross-language compatibility, and efficient memory management.
+
+The object system is built around three key classes: Object, ObjectPtr, and ObjectRef.
+The `Object` class is the base class of all heap-allocated objects. It contains a common header
+that includes the `type_index`, reference counters and deleter for the object.
+Users do not need to explicitly manage these fields as part of the C++ API. Instead,
+they are automatically managed through the smart pointer `ObjectPtr`, which points
+to a heap-allocated object instance.
+The following code shows an example object and the creation of an `ObjectPtr`:
+
+```cpp
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/memory.h>
+
+class MyIntPairObj : public tvm::ffi::Object {
+ public:
+  int64_t a;
+  int64_t b;
+
+  MyIntPairObj() = default;
+  MyIntPairObj(int64_t a, int64_t b) : a(a), b(b) {}
+
+  // Required: declare type information
+  // to register a dynamic type index through the system
+  static constexpr const char* _type_key = "example.MyIntPair";
+  // This macro registers the class with the FFI system to set up the right type index
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MyIntPairObj, tvm::ffi::Object);
+};
+
+void ExampleObjectPtr() {
+  namespace ffi = tvm::ffi;
+  // make_object automatically sets up the deleter correctly
+  // This function creates a new ObjectPtr with proper memory management
+  // It handles allocation, initialization, and sets up the reference counting system
+  ffi::ObjectPtr<MyIntPairObj> obj = ffi::make_object<MyIntPairObj>(100, 200);
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(obj->a, 100);
+  EXPECT_EQ(obj->b, 200);
+}
+```
+
+We typically provide a reference class that wraps the ObjectPtr.
+The `ObjectRef` base class provides the interface and reference counting
+functionality for these wrapper classes.
+```cpp
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/memory.h>
+
+class MyIntPair : public tvm::ffi::ObjectRef {
+ public:
+  // Constructor
+  explicit MyIntPair(int64_t a, int64_t b) {
+    data_ = tvm::ffi::make_object<MyIntPairObj>(a, b);
+  }
+
+  // Required: define object reference methods
+  // This macro provides the necessary methods for ObjectRef functionality
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS(MyIntPair, tvm::ffi::ObjectRef, MyIntPairObj);
+};
+
+void ExampleObjectRef() {
+  namespace ffi = tvm::ffi;
+  MyIntPair pair(100, 200);
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(pair->a, 100);
+  EXPECT_EQ(pair->b, 200);
+}
+```
+
+**Note:** The ObjectRef provides a user-friendly interface while ObjectPtr handles the low-level memory management.
+The ObjectRef acts as a smart pointer wrapper that automatically manages the ObjectPtr lifecycle.
+
+The overall implementation pattern is as follows:
+- **Object Class**: Inherits from `ffi::Object`, stores data and implements the core functionality.
+- **ObjectPtr**: Smart pointer that manages the Object lifecycle and reference counting.
+- **Ref Class**: Inherits from `ffi::ObjectRef`, provides a user-friendly interface and automatic memory management.
+
+This design ensures efficient memory management while providing a clean API for users. Once we define an ObjectRef class,
+we can integrate it with Any, AnyView and Function.
+
+```cpp
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/any.h>
+
+void ExampleObjectRefAny() {
+  namespace ffi = tvm::ffi;
+  MyIntPair pair(100, 200);
+  ffi::Any any = pair;
+  MyIntPair pair2 = any.cast<MyIntPair>();
+  // Note: EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(pair2->a, 100);
+  EXPECT_EQ(pair2->b, 200);
+}
+
+```
+
+Under the hood, ObjectPtr manages the lifecycle of the object through the same mechanism as shared pointers. We designed
+the object to be intrusive, which means the reference counter and type index metadata are embedded in the header of each object.
+This design allows us to allocate the control block and object memory together. As we will see in later sections,
+all of our heap-allocated classes, such as Function, on-heap String, Array and Map, are managed using subclasses of Object,
+and the user-facing classes such as Function are ObjectRefs.
+
+
+We provide a collection of built-in object and reference types, which are sufficient for common cases.
+Developers can also bring in new object types, as shown in the example in this section. We provide mechanisms
+to expose these objects to other language bindings such as Python.
+
+
+## Function
+
+The `Function` class provides a type-safe way to create and invoke callable objects
+through the tvm-ffi ABI convention. We can create an `ffi::Function` from an existing typed lambda function.
+
+```cpp
+#include <tvm/ffi/function.h>
+
+void ExampleFunctionFromTyped() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a typed lambda
+  ffi::Function fadd1 = ffi::Function::FromTyped(
+    [](const int a) -> int { return a + 1; }
+  );
+  int b = fadd1(1).cast<int>();
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(b, 2);
+}
+```
+
+Under the hood, tvm-ffi leverages Any and AnyView to create a unified ABI for
+all functions. The following example demonstrates the low-level way of defining
+a "packed" function for the same `fadd1`.
+
+```cpp
+void ExampleFunctionFromPacked() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a packed lambda
+  ffi::Function fadd1 = ffi::Function::FromPacked(
+    [](const ffi::AnyView* args, int32_t num_args, ffi::Any* rv) {
+      // Check that we have exactly one argument
+      TVM_FFI_ICHECK_EQ(num_args, 1);
+      int a = args[0].cast<int>();
+      *rv = a + 1;
+    }
+  );
+  int b = fadd1(1).cast<int>();
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(b, 2);
+}
+```
+
+At a high level, `ffi::Function` implements function calling with the following convention:
+- The arguments are passed through an on-stack array of `ffi::AnyView`
+- Return values are passed through `ffi::Any`
+
+Because the return value is `ffi::Any`, we need to explicitly call `cast` to convert the return
+value to the desired type. Importantly, `ffi::Function` itself is a value type that is compatible
+with tvm-ffi, which means we can pass it as an argument or return it as a value. The following code shows
+an example of passing a function as an argument and applying it inside.
+
+```cpp
+void ExampleFunctionPassFunction() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a typed lambda
+  ffi::Function fapply = ffi::Function::FromTyped(
+      [](const ffi::Function f, ffi::Any param) { return f(param.cast<int>()); });
+  ffi::Function fadd1 = ffi::Function::FromTyped(  //
+      [](const int a) -> int { return a + 1; });
+  int b = fapply(fadd1, 2).cast<int>();
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(b, 3);
+}
+```
+
+This pattern is very powerful because we can construct an `ffi::Function` not only from C++,
+but from any language that exposes functions through the tvm-ffi ABI. For example, this means we can easily call functions
+passed in or registered from Python for quick debugging or other purposes.
+
+
+### Global Function Registry
+
+Besides creating functions locally, tvm-ffi provides a global function registry that allows
+functions to be registered and called across different modules and languages.
+The following code shows an example:
+
+```cpp
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/registry.h>
+
+void ExampleGlobalFunctionRegistry() {
+  namespace ffi = tvm::ffi;
+  ffi::reflection::GlobalDef().def("xyz.add1", [](const int a) -> int { return a + 1; });
+  ffi::Function fadd1 = ffi::Function::GetGlobalRequired("xyz.add1");
+  int b = fadd1(1).cast<int>();
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(b, 2);
+}
+```
+
+You can also access and register global functions from the Python API.
+
+### Exporting as Library Symbol
+
+Besides the API that allows registration of functions into the global table,
+we also provide a macro to export static functions as `TVMFFISafeCallType` symbols in a dynamic library.
+
+```c++
+void AddOne(DLTensor* x, DLTensor* y) {
+  // ... implementation omitted ...
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne);
+```
+
+The new `add_one` takes the signature of `TVMFFISafeCallType` and can be wrapped as `ffi::Function`
+through the C++ `ffi::Module` API.
+
+```cpp
+ffi::Module mod = ffi::Module::LoadFromFile("path/to/export_lib.so");
+ffi::Function func = mod->GetFunction("add_one").value();
+```
+
+## Error Handling
+
+We provide a specific `ffi::Error` type that is also made compatible with the ffi ABI.
+We also provide a macro `TVM_FFI_THROW` to simplify the error throwing step.
+
+```cpp
+// file: cpp/test_example.cc
+#include <tvm/ffi/error.h>
+
+void FuncThrowError() {
+  namespace ffi = tvm::ffi;
+  TVM_FFI_THROW(TypeError) << "test0";
+}
+
+void ExampleErrorHandling() {
+  namespace ffi = tvm::ffi;
+  try {
+    FuncThrowError();
+  } catch (const ffi::Error& e) {
+    EXPECT_EQ(e.kind(), "TypeError");
+    EXPECT_EQ(e.message(), "test0");
+    std::cout << e.traceback() << std::endl;
+  }
+}
+```
+The structured error class records kind, message and traceback that can be mapped to
+Pythonic-style error types and tracebacks. The traceback follows the Python style, and
+tvm-ffi will try to preserve the traceback when possible. In the above example,
+you can see the traceback output as
+```
+... more lines omitted
+File "cpp/test_example.cc", line 106, in ExampleErrorHandling
+File "cpp/test_example.cc", line 100, in void FuncThrowError()
+```
+
+The ffi ABI provides minimal but sufficient mechanisms to propagate these errors across
+language boundaries.
+So when we call the function from Python, the Error will be translated into a corresponding
+Python error type. Similarly, when we call a Python callback from C++, the error will be translated
+into the right error kind and message.
+
+
+## NDArray
+
+For many use cases, we do not need to manage the nd-array/Tensor memory.
+In such cases, `DLTensor*` can be used as the function argument type.
+There are also cases that call for a managed container for multi-dimensional arrays.
+`ffi::NDArray` is a minimal container that provides such support.
+Notably, the specific logic of device allocation and array operations is a non-goal
+of the FFI. Instead, we provide the minimal generic API `ffi::NDArray::FromNDAlloc`
+to enable flexible customization of NDArray allocation.
+
+```cpp
+#include <tvm/ffi/container/ndarray.h>
+#include <tvm/ffi/container/shape.h>
+
+struct CPUNDAlloc {
+  void AllocData(DLTensor* tensor) {
+    tensor->data = malloc(tvm::ffi::GetDataSize(*tensor));
+  }
+  void FreeData(DLTensor* tensor) { free(tensor->data); }
+};
+
+void ExampleNDArray() {
+  namespace ffi = tvm::ffi;
+  ffi::Shape shape = {1, 2, 3};
+  DLDataType dtype = {kDLFloat, 32, 1};
+  DLDevice device = {kDLCPU, 0};
+  ffi::NDArray nd = ffi::NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
+  // now nd is a managed ndarray
+}
+```
+
+The above example shows how we define `CPUNDAlloc`, which customizes the `AllocData`
+and `FreeData` behavior. The CPUNDAlloc struct will be kept alive with the NDArray object.
+This pattern allows us to implement various NDArray allocation schemes using the same API:
+
+- For CUDA allocation, we can change malloc to cudaMalloc
+- For memory-pool based allocation, we can update `CPUNDAlloc` to keep a strong reference to the pool,
+  so the pool stays alive while the array is alive.
+
+**Working with Shapes** As you may have noticed in the example, we have an `ffi::Shape` container that is used
+to represent the shapes of nd-arrays. This container gives us a compact and efficient representation
+of managed shapes, and we provide quick conversions from standard vector types, as sketched below.
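+
+The following is a small sketch of that usage; it assumes the conversion constructor from
+`std::vector<int64_t>` described above, along with the usual `size()` and indexing accessors.
+
+```cpp
+#include <tvm/ffi/container/shape.h>
+
+#include <vector>
+
+void ExampleShape() {
+  namespace ffi = tvm::ffi;
+  // construct from an initializer list, as in the NDArray example above
+  ffi::Shape s1 = {1, 2, 3};
+  // construct from a standard vector (the quick conversion mentioned above)
+  std::vector<int64_t> dims = {4, 5, 6};
+  ffi::Shape s2(dims);
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(s2.size(), 3);
+  EXPECT_EQ(s2[0], 4);
+}
+```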
+
+### DLPack Conversion
+
+We provide first-class DLPack support for `ffi::NDArray`, which enables efficient exchange
+through the DLPack protocol.
+
+```cpp
+#include <tvm/ffi/container/ndarray.h>
+
+void ExampleNDArrayDLPack() {
+  namespace ffi = tvm::ffi;
+  ffi::Shape shape = {1, 2, 3};
+  DLDataType dtype = {kDLFloat, 32, 1};
+  DLDevice device = {kDLCPU, 0};
+  ffi::NDArray nd = ffi::NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
+  // convert to DLManagedTensorVersioned
+  DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned();
+  // load back from DLManagedTensorVersioned
+  ffi::NDArray nd2 = ffi::NDArray::FromDLPackVersioned(dlpack);
+}
+```
+
+These APIs are also available through the C APIs
+`TVMFFINDArrayFromDLPackVersioned` and `TVMFFINDArrayToDLPackVersioned`.
+
+## String and Bytes
+
+tvm-ffi provides first-class support for `String` and `Bytes` types that are efficient,
+FFI-compatible, and interoperable with standard C++ string types.
+
+```cpp
+#include <tvm/ffi/string.h>
+
+void ExampleString() {
+  namespace ffi = tvm::ffi;
+  ffi::String str = "hello world";
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(str.size(), 11);
+  std::string std_str = str;
+  EXPECT_EQ(std_str, "hello world");
+}
+```
+
+Alternatively, users can always directly use `std::string` in function arguments; the conversion
+happens automatically.
+
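+As a minimal sketch (ours, not from the test suite), a typed function can take `std::string`
+directly and be called with FFI string values:
+
+```cpp
+#include <tvm/ffi/function.h>
+
+#include <string>
+
+void ExampleStdStringArg() {
+  namespace ffi = tvm::ffi;
+  // std::string arguments are converted automatically at the FFI boundary
+  ffi::Function flen = ffi::Function::FromTyped(
+      [](const std::string s) -> int { return static_cast<int>(s.size()); });
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(flen("hello").cast<int>(), 5);
+}
+```
+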
+**Rationale:** We need separate `Bytes` and `String` types so that they map well to the
+corresponding Python types. `ffi::String` is backed by a (possibly managed) object, which makes
+it more compatible with the Object system.
+
+## Container Types
+
+To enable effective passing and storing of collections of values that are 
compatible with tvm-ffi,
+we provide several built-in container types.
+
+### Array
+
+`Array<T>` provides an array data type that can be used as a function argument.
+When `Array<T>` is used as an argument of a Function, the elements are checked at runtime
+to ensure the values match the expected type.
+
+```cpp
+#include <tvm/ffi/container/array.h>
+
+
+void ExampleArray() {
+  namespace ffi = tvm::ffi;
+  ffi::Array<int> numbers = {1, 2, 3};
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(numbers.size(), 3);
+  EXPECT_EQ(numbers[0], 1);
+
+  ffi::Function head = ffi::Function::FromTyped([](const ffi::Array<int> a) {
+    return a[0];
+  });
+  EXPECT_EQ(head(numbers).cast<int>(), 1);
+
+  try {
+    // throw an error because 2.2 is not int
+    head(ffi::Array<ffi::Any>({1, 2.2}));
+  } catch (const ffi::Error& e) {
+    EXPECT_EQ(e.kind(), "TypeError");
+  }
+}
+```
+
+Under the hood, Array is backed by a reference-counted Object `ArrayObj` that 
stores
+a collection of Any values. Note that conversion from Any to `Array<T>` will 
result in
+runtime checks of elements because the type index only indicates `ArrayObj` as 
the backing storage.
+If you want to defer such checks at the FFI function boundary, consider using 
`Array<Any>` instead.
+When passing lists and tuples from Python, the values will be converted to 
`Array<Any>` before
+being passed into the Function.
+
+**Performance note:** Repeatedly converting Any to `Array<T>` can incur 
repeated
+checking overhead at each element. Consider using `Array<Any>` to defer 
checking or only run conversion once.
+
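+As a minimal sketch (ours, not from the test suite), a function can accept `Array<Any>` and cast
+only the elements it actually uses:
+
+```cpp
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/function.h>
+
+void ExampleArrayAnyArg() {
+  namespace ffi = tvm::ffi;
+  // taking Array<Any> defers per-element checks to the point of use
+  ffi::Function sum_first_two = ffi::Function::FromTyped(
+      [](const ffi::Array<ffi::Any> a) -> int { return a[0].cast<int>() + a[1].cast<int>(); });
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(sum_first_two(ffi::Array<ffi::Any>({1, 2, 3})).cast<int>(), 3);
+}
+```
+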
+### Tuple
+
+`Tuple<Types...>` provides type-safe fixed-size collections.
+
+```cpp
+#include <tvm/ffi/container/tuple.h>
+
+void ExampleTuple() {
+  namespace ffi = tvm::ffi;
+  ffi::Tuple<int, ffi::String, bool> tup(42, "hello", true);
+
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(tup.get<0>(), 42);
+  EXPECT_EQ(tup.get<1>(), "hello");
+  EXPECT_EQ(tup.get<2>(), true);
+}
+```
+
+Under the hood, Tuple is backed by the same `ArrayObj` as the Array container.
+This enables zero-cost exchange with input arguments.
+
+**Rationale:** This design unifies the conversion rules from Python list/tuple 
to
+Array/Tuple. We always need a container representation for tuples
+to be stored in Any.
+
+### Map
+
+`Map<K, V>` provides a key-value based hashmap container that can accept 
dict-style parameters.
+
+```cpp
+#include <tvm/ffi/container/map.h>
+
+void ExampleMap() {
+  namespace ffi = tvm::ffi;
+
+  ffi::Map<ffi::String, int> map0 = {{"Alice", 100}, {"Bob", 95}};
+
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(map0.size(), 2);
+  EXPECT_EQ(map0.at("Alice"), 100);
+  EXPECT_EQ(map0.count("Alice"), 1);
+}
+```
+
+
+Under the hood, Map is backed by a reference-counted Object `MapObj` that 
stores
+a collection of Any values. The implementation provides a SmallMap variant 
that stores
+values as an array and another variant that is based on a hashmap. The Map 
preserves insertion
+order like Python dictionaries. Conversion from Any to `Map<K, V>` will result 
in
+runtime checks of its elements because the type index only indicates `MapObj` 
as the backing storage.
+If you want to defer such checks at the FFI function boundary, consider using 
`Map<Any, Any>` instead.
+When passing dictionaries from Python, the values will be converted to 
`Map<Any, Any>` before
+being passed into the Function.
+
+**Performance note:** Repeatedly converting Any to `Map<K, V>` can incur 
repeated
+checking overhead at each element. Consider using `Map<Any, Any>` to defer 
checking or only run conversion once.
+
+### Optional
+
+`Optional<T>` provides a safe way to handle values that may or may not exist.
+We specialize Optional for `ffi::String` and Object types to be more compact,
+using nullptr to indicate non-existence.
+
+```cpp
+#include <tvm/ffi/container/optional.h>
+
+void ExampleOptional() {
+  namespace ffi = tvm::ffi;
+  ffi::Optional<int> opt0 = 100;
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(opt0.has_value(), true);
+  EXPECT_EQ(opt0.value(), 100);
+
+  ffi::Optional<ffi::String> opt1;
+  EXPECT_EQ(opt1.has_value(), false);
+  EXPECT_EQ(opt1.value_or("default"), "default");
+}
+```
+
+
+### Variant
+
+`Variant<Types...>` provides a type-safe union of different types.
+
+```cpp
+#include <tvm/ffi/container/variant.h>
+
+void ExampleVariant() {
+  namespace ffi = tvm::ffi;
+  ffi::Variant<int, ffi::String> var0 = 100;
+  // EXPECT_EQ is used here for demonstration purposes (testing framework)
+  EXPECT_EQ(var0.get<int>(), 100);
+
+  var0 = ffi::String("hello");
+  std::optional<ffi::String> maybe_str = var0.as<ffi::String>();
+  EXPECT_EQ(maybe_str.value(), "hello");
+
+  std::optional<int> maybe_int2 = var0.as<int>();
+  EXPECT_EQ(maybe_int2.has_value(), false);
+}
+```
+
+Under the hood, Variant is a wrapper around Any that restricts the type to the 
specific types in the list.
diff --git a/ffi/docs/guides/packaging.md b/ffi/docs/guides/packaging.md
new file mode 100644
index 0000000000..544a45e52d
--- /dev/null
+++ b/ffi/docs/guides/packaging.md
@@ -0,0 +1,282 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# Packaging
+
+This guide explains how to package a tvm-ffi-based library into a Python 
ABI-agnostic wheel.
+It demonstrates both source-level builds (for cross-compilation) and builds 
based on pre-shipped shared libraries.
+At a high level, packaging with tvm-ffi offers several benefits:
+
+- **ABI-agnostic wheels**: Work across different Python versions with minimal dependencies.
+- **Universally deployable**: Build once with tvm-ffi and ship to different 
environments, including Python and non-Python environments.
+
+While this guide shows how to build a wheel package, the resulting 
`my_ffi_extension.so` is agnostic
+to Python, comes with minimal dependencies, and can be used in other 
deployment scenarios.
+
+## Build and Run the Example
+
+Let's start by building and running the example.
+First, obtain a copy of the tvm-ffi source code.
+
+```bash
+git clone https://github.com/apache/tvm --recursive
+cd tvm/ffi
+```
+
+The example lives in the `examples/packaging` folder. You can quickly build
+and install it using the following commands.
+```bash
+cd examples/packaging
+pip install -v .
+```
+
+Then you can run the example, which uses the installed package.
+
+```bash
+python run_example.py add_one
+```
+
+## Setup pyproject.toml
+
+A typical tvm-ffi-based project has the following structure:
+
+```
+├── CMakeLists.txt          # CMake build configuration
+├── pyproject.toml          # Python packaging configuration
+├── src/
+│   └── extension.cc        # C++ source code
+├── python/
+│   └── my_ffi_extension/
+│       ├── __init__.py     # Python package initialization
+│       ├── base.py         # Library loading logic
+│       └── _ffi_api.py     # FFI API registration
+└── README.md               # Project documentation
+```
+
+The `pyproject.toml` file configures the build system and project metadata.
+
+```toml
+[project]
+name = "my-ffi-extension"
+version = "0.1.0"
+# ... more project metadata omitted ...
+
+[build-system]
+requires = ["scikit-build-core>=0.10.0", "apache-tvm-ffi"]
+build-backend = "scikit_build_core.build"
+
+[tool.scikit-build]
+# ABI-agnostic wheel
+wheel.py-api = "py3"
+# ... more build configuration omitted ...
+```
+
+We use scikit-build-core to build the wheel. Make sure to add `apache-tvm-ffi` as a build-system
+requirement. Importantly, set `wheel.py-api` to `py3` to mark the wheel as ABI-agnostic.
+
+## Setup CMakeLists.txt
+
+The CMakeLists.txt handles the build and linking of the project.
+There are two ways you can build with tvm-ffi:
+
+- Link the pre-built `libtvm_ffi` shipped from the pip package
+- Build tvm-ffi from source
+
+For common cases, using the pre-built library and linking against `tvm_ffi_shared` is sufficient.
+To build with the pre-built library, you can do:
+
+```cmake
+cmake_minimum_required(VERSION 3.18)
+project(my_ffi_extension)
+
+find_package(Python COMPONENTS Interpreter REQUIRED)
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
+  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
+# find the prebuilt package
+find_package(tvm_ffi CONFIG REQUIRED)
+
+# ... more cmake configuration omitted ...
+
+# linking the library
+target_link_libraries(my_ffi_extension tvm_ffi_shared)
+```
+
+There are cases where one may want to cross-compile or bundle part of tvm_ffi 
objects directly
+into the project. In such cases, you should build from source.
+
+```cmake
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --sourcedir
+  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
+# add the shipped source code as a cmake subdirectory
+add_subdirectory(${tvm_ffi_ROOT} tvm_ffi)
+
+# ... more cmake configuration omitted ...
+
+# linking the library
+target_link_libraries(my_ffi_extension tvm_ffi_shared)
+```
+Note that it is always safe to build from source; the extra cost of building tvm-ffi is small
+because tvm-ffi is a lightweight library. If you are in doubt, build tvm-ffi from source.
+In Python, or in other cases where the libtvm_ffi shipped with the dedicated pip package is
+loaded dynamically, you do not need to ship libtvm_ffi.so in your package even if you build
+tvm-ffi from source. The built objects are only used to supply the linking information.
+
+## Exposing C++ Functions
+
+The C++ implementation is defined in `src/extension.cc`.
+There are two ways one can expose a function in C++ to the FFI library.
+First, `TVM_FFI_DLL_EXPORT_TYPED_FUNC` can be used to expose the function 
directly as a C symbol that follows the tvm-ffi ABI,
+which can later be accessed via `tvm_ffi.load_module`.
+
+Here's a basic example of the function implementation:
+
+```c++
+void AddOne(DLTensor* x, DLTensor* y) {
+  // ... implementation omitted ...
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne);
+```
+
+We can also register a function into the global function table with a given 
name:
+
+```c++
+void RaiseError(ffi::String msg) {
+  TVM_FFI_THROW(RuntimeError) << msg;
+}
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+    .def("my_ffi_extension.raise_error", RaiseError);
+});
+```
+
+When registering a global function, make sure its name is unique across all registered functions;
+always prefix it with a package namespace to avoid name collisions.
+The function can then be found via `tvm_ffi.get_global_func(name)`
+and is expected to stay registered for the lifetime of the program.
+
+We recommend using `TVM_FFI_DLL_EXPORT_TYPED_FUNC` for functions that are intended to be
+dynamically loaded (such as in JIT scenarios), so they are not exposed through the global
+function table.
+
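+To illustrate how the two styles are accessed from Python (the library path below is
+illustrative):
+
+```python
+import tvm_ffi
+
+# symbol exported with TVM_FFI_DLL_EXPORT_TYPED_FUNC: access it on the loaded module
+mod = tvm_ffi.load_module("build/my_ffi_extension.so")
+mod.add_one  # available as a module function
+
+# function registered via the reflection registry: look it up in the global table
+raise_error = tvm_ffi.get_global_func("my_ffi_extension.raise_error")
+```
+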
+## Library Loading in Python
+
+The base module handles loading the compiled extension:
+
+```python
+import tvm_ffi
+import os
+import sys
+
+def _load_lib():
+    file_dir = os.path.dirname(os.path.realpath(__file__))
+
+    # Platform-specific library names
+    if sys.platform.startswith("win32"):
+        lib_name = "my_ffi_extension.dll"
+    elif sys.platform.startswith("darwin"):
+        lib_name = "my_ffi_extension.dylib"
+    else:
+        lib_name = "my_ffi_extension.so"
+
+    lib_path = os.path.join(file_dir, lib_name)
+    return tvm_ffi.load_module(lib_path)
+
+_LIB = _load_lib()
+```
+
+Effectively, it uses the `tvm_ffi.load_module` call to load the extension library
+shipped along with the package. The `_ffi_api.py` module then calls
+`tvm_ffi._init_api`, which makes all global functions prefixed
+with `my_ffi_extension` available in that module.
+
+```python
+# _ffi_api.py
+import tvm_ffi
+from .base import _LIB
+
+# Register all global functions prefixed with 'my_ffi_extension.'
+# This makes functions registered via TVM_FFI_STATIC_INIT_BLOCK available
+tvm_ffi._init_api("my_ffi_extension", __name__)
+```
+
+The user-facing API can then redirect calls to the related functions.
+
+```python
+from .base import _LIB
+from . import _ffi_api
+
+def add_one(x, y):
+    # ... docstring omitted ...
+    return _LIB.add_one(x, y)
+
+def raise_error(msg):
+    # ... docstring omitted ...
+    return _ffi_api.raise_error(msg)
+```
+
+## Build and Use the Package
+
+First, build the wheel:
+```bash
+pip wheel -v -w dist .
+```
+
+Then install the built wheel:
+```bash
+pip install dist/*.whl
+```
+
+Then you can try it out:
+
+```python
+import torch
+import my_ffi_extension
+
+# Create input and output tensors
+x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+y = torch.empty_like(x)
+
+# Call the function
+my_ffi_extension.add_one(x, y)
+print(y)  # Output: tensor([2., 3., 4., 5., 6.])
+```
+
+You can also run the following command to see how errors are raised and 
propagated
+across language boundaries:
+
+```bash
+python run_example.py raise_error
+```
+
+When possible, tvm-ffi will try to preserve tracebacks across language 
boundaries. You will see tracebacks like:
+```
+File "src/extension.cc", line 45, in void 
my_ffi_extension::RaiseError(tvm::ffi::String)
+```
+
+## Wheel Auditing
+
+When using `auditwheel`, exclude `libtvm_ffi` as it will be shipped with the 
`tvm_ffi` package.
+
+```bash
+auditwheel repair --exclude libtvm_ffi.so dist/*.whl
+```
+
+As long as you import `tvm_ffi` first before loading the library, the symbols 
will be available.
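+
+For example, in user code it is enough to import `tvm_ffi` before the extension package:
+
+```python
+# importing tvm_ffi first ensures libtvm_ffi and its symbols are loaded
+import tvm_ffi
+import my_ffi_extension
+```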
diff --git a/ffi/docs/guides/python_guide.md b/ffi/docs/guides/python_guide.md
new file mode 100644
index 0000000000..2d588049ae
--- /dev/null
+++ b/ffi/docs/guides/python_guide.md
@@ -0,0 +1,243 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+# Python Guide
+
+This guide introduces the `tvm_ffi` Python package.
+At a high level, the `tvm_ffi` Python package provides first-class Python support for:
+
+- Pythonic classes to represent values in the TVM FFI Any ABI.
+- Mechanisms to call into TVM FFI ABI-compatible functions.
+- Conversion between Python values and `tvm_ffi` values.
+
+In this guide, we will run examples that make use of testing functions pre-registered in
+`tvm_ffi`. Where we do, we also include brief C++ snippets that show the corresponding behavior.
+
+## Load and Run Module
+
+The most common use case of TVM FFI is to load a runnable module and run the 
corresponding function.
+You can follow the [quick start guide](../get_started/quick_start.md) for 
details on building the
+library `build/add_one_cpu.so`. Let's walk through the load-and-run example again, this time with NumPy:
+
+```python
+import tvm_ffi
+import numpy as np
+
+# Load the compiled module
+mod = tvm_ffi.load_module("build/add_one_cpu.so")
+
+# Create input and output arrays
+x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
+y = np.empty_like(x)
+
+# Call the function
+mod.add_one_cpu(x, y)
+```
+
+In this case, `tvm_ffi.load_module` will return a `tvm_ffi.Module` class that 
contains
+the exported functions. You can access the functions by their names.
+
+## NDArray
+
+`tvm_ffi` provides a managed DLPack-compatible NDArray.
+
+```python
+import numpy as np
+import tvm_ffi
+
+# Demonstrate DLPack conversion between NumPy and TVM FFI
+np_data = np.array([1, 2, 3, 4], dtype=np.float32)
+tvm_array = tvm_ffi.from_dlpack(np_data)
+# Convert back to NumPy
+np_result = np.from_dlpack(tvm_array)
+```
+
+In most cases, however, you do not have to explicitly create NDArrays.
+The Python interface can take in `torch.Tensor` and `numpy.ndarray` objects
+and automatically convert them to `tvm_ffi.NDArray`.
+
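+For example (a sketch, reusing the module built in the quick start), a `torch.Tensor` can be
+passed directly to the loaded function:
+
+```python
+import torch
+import tvm_ffi
+
+mod = tvm_ffi.load_module("build/add_one_cpu.so")
+
+x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+y = torch.empty_like(x)
+# torch.Tensor arguments are accepted and converted automatically
+mod.add_one_cpu(x, y)
+print(y)  # tensor([2., 3., 4., 5., 6.])
+```
+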
+## Functions and Callbacks
+
+`tvm_ffi.Function` provides the Python interface for `ffi::Function` in C++.
+You can retrieve globally registered functions via `tvm_ffi.get_global_func()`.
+
+```python
+import tvm_ffi
+
+# testing.echo is defined and registered in C++
+# [](ffi::Any x) { return x; }
+fecho = tvm_ffi.get_global_func("testing.echo")
+assert fecho(1) == 1
+```
+
+You can pass a Python function as an argument to another FFI function as a callback.
+Under the hood, `tvm_ffi.convert` is called to convert the Python function into a
+`tvm_ffi.Function`.
+
+```python
+import tvm_ffi
+
+# testing.apply is registered in C++
+# [](ffi::Function f, ffi::Any val) { return f(val); }
+fapply = tvm_ffi.get_global_func("testing.apply")
+# invoke fapply with lambda callback as f
+assert fapply(lambda x: x + 1, 1) == 2
+```
+
+This is a very powerful pattern that allows us to inject Python callbacks into 
the C++ code.
+You can also register a Python callback as a global function.
+
+```python
+import tvm_ffi
+
+@tvm_ffi.register_func("example.add_one")
+def add_one(a):
+    return a + 1
+
+assert tvm_ffi.get_global_func("example.add_one")(1) == 2
+```
+
+## Container Types
+
+When Python lists or tuples are passed to an FFI function, they are converted into `tvm_ffi.Array`.
+
+```python
+import tvm_ffi
+
+# Lists become Arrays
+arr = tvm_ffi.convert([1, 2, 3, 4])
+assert isinstance(arr, tvm_ffi.Array)
+assert len(arr) == 4
+assert arr[0] == 1
+```
+
+Dictionaries are converted to `tvm_ffi.Map`:
+
+```python
+import tvm_ffi
+
+map_obj = tvm_ffi.convert({"a": 1, "b": 2})
+assert isinstance(map_obj, tvm_ffi.Map)
+assert len(map_obj) == 2
+assert map_obj["a"] == 1
+assert map_obj["b"] == 2
+```
+
+When container values are returned from FFI functions, they are likewise stored in these types.
+
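+For example (a sketch using the `testing.echo` function shown above), a list echoed back through
+an FFI function comes back as a `tvm_ffi.Array`:
+
+```python
+import tvm_ffi
+
+# testing.echo returns its argument unchanged; the Python list is converted to
+# a tvm_ffi.Array on the way in, so the returned value is a tvm_ffi.Array
+fecho = tvm_ffi.get_global_func("testing.echo")
+ret = fecho([1, 2, 3])
+assert isinstance(ret, tvm_ffi.Array)
+assert len(ret) == 3 and ret[0] == 1
+```
+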
+
+## Error Handling
+
+An FFI function may raise an error. In such cases, the Python package will automatically
+translate the error into the corresponding error kind in Python:
+
+```python
+import tvm_ffi
+
+# defined in C++
+# [](String kind, String msg) { throw Error(kind, msg, traceback); }
+test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error")
+
+test_raise_error("ValueError", "message")
+```
+The above code shows an example where an error is raised in C++, resulting in the following traceback:
+```
+Traceback (most recent call last):
+File "example.py", line 7, in <module>
+  test_raise_error("ValueError", "message")
+  ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
+File "python/tvm_ffi/cython/function.pxi", line 325, in core.Function.__call__
+  raise move_from_last_error().py_error()
+  ^^^
+File "src/ffi/extra/testing.cc", line 60, in void 
tvm::ffi::TestRaiseError(tvm::ffi::String, tvm::ffi::String)
+  throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, 
TVM_FFI_FUNC_SIG, 0));
+```
+
+Common error kinds are registered by default. You can also register additional error mappings via
+the `tvm_ffi.register_error` function.
+
+## Advanced: Register Your Own Object
+
+For advanced use cases, you may want to register your own objects. This can be 
achieved through the
+reflection registry in the TVM-FFI API. First, let's review the C++ side of 
the code. For this
+example, you do not need to change the C++ side as this code is pre-shipped 
with the testing module of the `tvm_ffi` package.
+
+```cpp
+#include <tvm/ffi/reflection/registry.h>
+
+// Step 1: Define the object class (stores the actual data)
+class TestIntPairObj : public tvm::ffi::Object {
+public:
+  int64_t a;
+  int64_t b;
+
+  TestIntPairObj() = default;
+  TestIntPairObj(int64_t a, int64_t b) : a(a), b(b) {}
+
+  // Required: declare type information
+  static constexpr const char* _type_key = "testing.TestIntPair";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestIntPairObj, tvm::ffi::Object);
+};
+
+// Step 2: Define the reference wrapper (user-facing interface)
+class TestIntPair : public tvm::ffi::ObjectRef {
+public:
+  // Constructor
+  explicit TestIntPair(int64_t a, int64_t b) {
+    data_ = tvm::ffi::make_object<TestIntPairObj>(a, b);
+  }
+
+  // Required: define object reference methods
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS(TestIntPair, tvm::ffi::ObjectRef, 
TestIntPairObj);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  // register the object into the system
+  // register field accessors and a global static function `__create__` as 
ffi::Function
+  refl::ObjectDef<TestIntPairObj>()
+    .def_ro("a", &TestIntPairObj::a)
+    .def_ro("b", &TestIntPairObj::b)
+    .def_static("__create__", [](int64_t a, int64_t b) -> TestIntPair {
+      return TestIntPair(a, b);
+    });
+});
+```
+
+You can then create wrapper classes for objects that are in the library as 
follows:
+
+```python
+import tvm_ffi
+
+# Register the class
+@tvm_ffi.register_object("testing.TestIntPair")
+class TestIntPair(tvm_ffi.Object):
+    def __init__(self, a, b):
+        # This is a special method to call an FFI function whose return
+        # value exactly initializes the object handle of the object
+        self.__init_handle_by_constructor__(TestIntPair.__create__, a, b)
+
+test_int_pair = TestIntPair(1, 2)
+# We can access the fields by name
+# The properties are populated by the reflection mechanism
+assert test_int_pair.a == 1
+assert test_int_pair.b == 2
+```
+Under the hood, we leverage the information registered through the reflection 
registry to
+generate efficient field accessors and methods for each class.
+
+Importantly, when your objects form an inheritance hierarchy, you need to call
+`tvm_ffi.register_object` on both the base class and the child class.
diff --git a/ffi/docs/index.rst b/ffi/docs/index.rst
new file mode 100644
index 0000000000..c3f0b3ea51
--- /dev/null
+++ b/ffi/docs/index.rst
@@ -0,0 +1,41 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Apache TVM FFI Documentation
+============================
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Get Started
+
+   get_started/install.md
+   get_started/quick_start.md
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Guides
+
+   guides/packaging.md
+   guides/cpp_guide.md
+   guides/python_guide.md
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Concepts
+
+   concepts/abi_overview.md
diff --git a/ffi/docs/requirements.txt b/ffi/docs/requirements.txt
new file mode 100644
index 0000000000..b7be6f6d62
--- /dev/null
+++ b/ffi/docs/requirements.txt
@@ -0,0 +1,18 @@
+autodocsumm
+matplotlib
+myst-parser
+nbconvert
+nbsphinx
+nbstripout
+sphinx
+sphinx-autobuild
+sphinx-book-theme
+sphinx-copybutton
+sphinx-reredirects==0.1.2
+sphinx-tabs == 3.4.1
+sphinx-toolbox == 3.4.0
+sphinxcontrib-mermaid
+sphinxcontrib-napoleon==0.7
+sphinxcontrib_httpdomain==1.8.1
+tomli
+urllib3>=2.5.0
diff --git a/ffi/examples/packaging/CMakeLists.txt 
b/ffi/examples/packaging/CMakeLists.txt
index 47e5040a0d..ed55f7ca33 100644
--- a/ffi/examples/packaging/CMakeLists.txt
+++ b/ffi/examples/packaging/CMakeLists.txt
@@ -16,7 +16,7 @@
 # under the License.
 
 cmake_minimum_required(VERSION 3.18)
-project(tvm_ffi_extension)
+project(my_ffi_extension)
 
 option(TVM_FFI_EXT_FROM_SOURCE "Build tvm_ffi from source, useful for cross 
compilation." ON)
 option(TVM_FFI_EXT_SHIP_DEBUG_SYMBOLS "Ship debug symbols" ON)
@@ -35,7 +35,7 @@ option(TVM_FFI_EXT_SHIP_DEBUG_SYMBOLS "Ship debug symbols" ON)
 #   So when in doubt, you can always choose to the building tvm_ffi from 
source route.
 #
 # In python or other cases when we dynamically load libtvm_ffi_shared. Even 
when you build
-# from source, you do not need to ship libtvm_ffi_shared.so built here as they 
are only
+# from source, you do not need to ship libtvm_ffi.so built here as they are 
only
 # used to supply the linking information.
 # first find python related components
 find_package(Python COMPONENTS Interpreter REQUIRED)
@@ -54,20 +54,20 @@ else()
 endif()
 
 # use the projects as usual
-add_library(tvm_ffi_extension SHARED src/extension.cc)
-target_link_libraries(tvm_ffi_extension tvm_ffi_header)
-target_link_libraries(tvm_ffi_extension tvm_ffi_shared)
+add_library(my_ffi_extension SHARED src/extension.cc)
+target_link_libraries(my_ffi_extension tvm_ffi_header)
+target_link_libraries(my_ffi_extension tvm_ffi_shared)
 
-# show as tvm_ffi_extension.so
+# show as my_ffi_extension.so
 set_target_properties(
-  tvm_ffi_extension PROPERTIES PREFIX ""
+  my_ffi_extension PROPERTIES PREFIX ""
 )
 
 if (TVM_FFI_EXT_SHIP_DEBUG_SYMBOLS)
   # ship debugging symbols for backtrace on macos
-  tvm_ffi_add_prefix_map(tvm_ffi_extension ${CMAKE_CURRENT_SOURCE_DIR})
-  tvm_ffi_add_apple_dsymutil(tvm_ffi_extension)
+  tvm_ffi_add_prefix_map(my_ffi_extension ${CMAKE_CURRENT_SOURCE_DIR})
+  tvm_ffi_add_apple_dsymutil(my_ffi_extension)
   install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ DESTINATION . FILES_MATCHING 
PATTERN "*.dSYM")
 endif()
 
-install(TARGETS tvm_ffi_extension DESTINATION .)
+install(TARGETS my_ffi_extension DESTINATION .)
diff --git a/ffi/examples/packaging/README.md b/ffi/examples/packaging/README.md
index 9535581af6..25bcc1ca3c 100644
--- a/ffi/examples/packaging/README.md
+++ b/ffi/examples/packaging/README.md
@@ -35,11 +35,11 @@ pip install .
 ### Note on build and auditwheel
 
 Note: When running the auditwheel process, make sure to skip
-`libtvm_ffi_shared.so` as they are shipped via the tvm_ffi package.
+`libtvm_ffi.so` as they are shipped via the tvm_ffi package.
 
 ## Run the example
 
-After installing the `tvm_ffi_extension` example package, you can run the 
following example
+After installing the `my_ffi_extension` example package, you can run the 
following example
 that invokes the `add_one` function exposed.
 
 ```bash
@@ -55,7 +55,7 @@ python run_example.py raise_error
 
 When possible, tvm_ffi will try to preserve traceback across language 
boundary. You will see traceback like
 ```
-File "src/extension.cc", line 45, in void 
tvm_ffi_extension::RaiseError(tvm::ffi::String)
+File "src/extension.cc", line 45, in void 
my_ffi_extension::RaiseError(tvm::ffi::String)
 ```
 If you are in an IDE like VSCode, you can click and jump to the C++ lines of 
error when
 the debug symbols are preserved.
diff --git a/ffi/examples/packaging/pyproject.toml 
b/ffi/examples/packaging/pyproject.toml
index e38ebeccff..7825ca81ce 100644
--- a/ffi/examples/packaging/pyproject.toml
+++ b/ffi/examples/packaging/pyproject.toml
@@ -16,7 +16,7 @@
 # under the License.
 
 [project]
-name = "tvm-ffi-extension"
+name = "my-ffi-extension"
 version = "0.1.0"
 
 readme = "README.md"
@@ -54,5 +54,5 @@ cmake.build-type = "RelWithDebugInfo"
 logging.level = "INFO"
 
 # Wheel configuration
-wheel.packages = ["python/tvm_ffi_extension"]
-wheel.install-dir = "tvm_ffi_extension"
+wheel.packages = ["python/my_ffi_extension"]
+wheel.install-dir = "my_ffi_extension"
diff --git a/ffi/examples/packaging/python/tvm_ffi_extension/__init__.py 
b/ffi/examples/packaging/python/my_ffi_extension/__init__.py
similarity index 100%
rename from ffi/examples/packaging/python/tvm_ffi_extension/__init__.py
rename to ffi/examples/packaging/python/my_ffi_extension/__init__.py
diff --git a/ffi/examples/packaging/python/tvm_ffi_extension/_ffi_api.py 
b/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py
similarity index 90%
rename from ffi/examples/packaging/python/tvm_ffi_extension/_ffi_api.py
rename to ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py
index 1ab9abd765..79c269ab0a 100644
--- a/ffi/examples/packaging/python/tvm_ffi_extension/_ffi_api.py
+++ b/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py
@@ -20,5 +20,5 @@ import tvm_ffi
 from .base import _LIB
 
 # this is a short cut to register all the global functions
-# prefixed by `tvm_ffi_extension.` to this module
-tvm_ffi._init_api("tvm_ffi_extension", __name__)
+# prefixed by `my_ffi_extension.` to this module
+tvm_ffi._init_api("my_ffi_extension", __name__)
diff --git a/ffi/examples/packaging/python/tvm_ffi_extension/base.py 
b/ffi/examples/packaging/python/my_ffi_extension/base.py
similarity index 89%
rename from ffi/examples/packaging/python/tvm_ffi_extension/base.py
rename to ffi/examples/packaging/python/my_ffi_extension/base.py
index ed73193770..d65264eb71 100644
--- a/ffi/examples/packaging/python/tvm_ffi_extension/base.py
+++ b/ffi/examples/packaging/python/my_ffi_extension/base.py
@@ -24,11 +24,11 @@ def _load_lib():
     file_dir = os.path.dirname(os.path.realpath(__file__))
 
     if sys.platform.startswith("win32"):
-        lib_dll_name = "tvm_ffi_extension.dll"
+        lib_dll_name = "my_ffi_extension.dll"
     elif sys.platform.startswith("darwin"):
-        lib_dll_name = "tvm_ffi_extension.dylib"
+        lib_dll_name = "my_ffi_extension.dylib"
     else:
-        lib_dll_name = "tvm_ffi_extension.so"
+        lib_dll_name = "my_ffi_extension.so"
 
     lib_path = os.path.join(file_dir, lib_dll_name)
     return tvm_ffi.load_module(lib_path)
diff --git a/ffi/examples/packaging/run_example.py 
b/ffi/examples/packaging/run_example.py
index 88efae20cc..11642257e8 100644
--- a/ffi/examples/packaging/run_example.py
+++ b/ffi/examples/packaging/run_example.py
@@ -16,18 +16,18 @@
 # Base logic to load library for extension package
 import torch
 import sys
-import tvm_ffi_extension
+import my_ffi_extension
 
 
 def run_add_one():
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
     y = torch.empty_like(x)
-    tvm_ffi_extension.add_one(x, y)
+    my_ffi_extension.add_one(x, y)
     print(y)
 
 
 def run_raise_error():
-    tvm_ffi_extension.raise_error("This is an error")
+    my_ffi_extension.raise_error("This is an error")
 
 
 if __name__ == "__main__":
diff --git a/ffi/examples/packaging/src/extension.cc 
b/ffi/examples/packaging/src/extension.cc
index 20a1f91fda..eb4be8508d 100644
--- a/ffi/examples/packaging/src/extension.cc
+++ b/ffi/examples/packaging/src/extension.cc
@@ -29,7 +29,7 @@
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
-namespace tvm_ffi_extension {
+namespace my_ffi_extension {
 
 namespace ffi = tvm::ffi;
 
@@ -57,7 +57,7 @@ void AddOne(DLTensor* x, DLTensor* y) {
 }
 
 // expose global symbol add_one
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, tvm_ffi_extension::AddOne);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne);
 
 // The static initialization block is
 // called once when the library is loaded.
@@ -83,6 +83,6 @@ TVM_FFI_STATIC_INIT_BLOCK({
   // When registering via reflection mechanisms, the library do not need to be 
loaded via
   // tvm::ffi::Module::LoadFromFile, instead, just load the dll or simply 
bundle into the
   // final project
-  refl::GlobalDef().def("tvm_ffi_extension.raise_error", RaiseError);
+  refl::GlobalDef().def("my_ffi_extension.raise_error", RaiseError);
 });
-}  // namespace tvm_ffi_extension
+}  // namespace my_ffi_extension
diff --git a/ffi/examples/quick_start/get_started/CMakeLists.txt 
b/ffi/examples/quick_start/CMakeLists.txt
similarity index 100%
rename from ffi/examples/quick_start/get_started/CMakeLists.txt
rename to ffi/examples/quick_start/CMakeLists.txt
diff --git a/ffi/examples/quick_start/get_started/README.md 
b/ffi/examples/quick_start/README.md
similarity index 100%
rename from ffi/examples/quick_start/get_started/README.md
rename to ffi/examples/quick_start/README.md
diff --git a/ffi/examples/quick_start/get_started/run_example.py 
b/ffi/examples/quick_start/run_example.py
similarity index 100%
rename from ffi/examples/quick_start/get_started/run_example.py
rename to ffi/examples/quick_start/run_example.py
diff --git a/ffi/examples/quick_start/get_started/run_example.sh 
b/ffi/examples/quick_start/run_example.sh
similarity index 100%
rename from ffi/examples/quick_start/get_started/run_example.sh
rename to ffi/examples/quick_start/run_example.sh
diff --git a/ffi/examples/quick_start/get_started/src/add_one_cpu.cc 
b/ffi/examples/quick_start/src/add_one_cpu.cc
similarity index 100%
rename from ffi/examples/quick_start/get_started/src/add_one_cpu.cc
rename to ffi/examples/quick_start/src/add_one_cpu.cc
diff --git a/ffi/examples/quick_start/get_started/src/add_one_cuda.cu 
b/ffi/examples/quick_start/src/add_one_cuda.cu
similarity index 100%
rename from ffi/examples/quick_start/get_started/src/add_one_cuda.cu
rename to ffi/examples/quick_start/src/add_one_cuda.cu
diff --git a/ffi/examples/quick_start/get_started/src/run_example.cc 
b/ffi/examples/quick_start/src/run_example.cc
similarity index 100%
rename from ffi/examples/quick_start/get_started/src/run_example.cc
rename to ffi/examples/quick_start/src/run_example.cc
diff --git a/ffi/src/ffi/extra/testing.cc b/ffi/src/ffi/extra/testing.cc
index 1a7bdb4e68..0800d48795 100644
--- a/ffi/src/ffi/extra/testing.cc
+++ b/ffi/src/ffi/extra/testing.cc
@@ -30,6 +30,41 @@
 namespace tvm {
 namespace ffi {
 
+// Step 1: Define the object class (stores the actual data)
+class TestIntPairObj : public tvm::ffi::Object {
+ public:
+  int64_t a;
+  int64_t b;
+
+  TestIntPairObj() = default;
+  TestIntPairObj(int64_t a, int64_t b) : a(a), b(b) {}
+
+  // Required: declare type information
+  static constexpr const char* _type_key = "testing.TestIntPair";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestIntPairObj, tvm::ffi::Object);
+};
+
+// Step 2: Define the reference wrapper (user-facing interface)
+class TestIntPair : public tvm::ffi::ObjectRef {
+ public:
+  // Constructor
+  explicit TestIntPair(int64_t a, int64_t b) {
+    data_ = tvm::ffi::make_object<TestIntPairObj>(a, b);
+  }
+
+  // Required: define object reference methods
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS(TestIntPair, tvm::ffi::ObjectRef, 
TestIntPairObj);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<TestIntPairObj>()
+      .def_ro("a", &TestIntPairObj::a)
+      .def_ro("b", &TestIntPairObj::b)
+      .def_static("__create__",
+                  [](int64_t a, int64_t b) -> TestIntPair { return 
TestIntPair(a, b); });
+});
+
 class TestObjectBase : public Object {
  public:
   int64_t v_i64;
diff --git a/ffi/tests/cpp/test_example.cc b/ffi/tests/cpp/test_example.cc
new file mode 100644
index 0000000000..68e5298219
--- /dev/null
+++ b/ffi/tests/cpp/test_example.cc
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/ndarray.h>
+#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/registry.h>
+
+// test-cases used in example code
+namespace {
+
+void ExampleAny() {
+  namespace ffi = tvm::ffi;
+  // Create an Any from various types
+  ffi::Any int_value = 42;
+  ffi::Any float_value = 3.14;
+  ffi::Any string_value = "hello world";
+
+  // AnyView provides a lightweight view without ownership
+  ffi::AnyView view = int_value;
+  // we can cast Any/AnyView to a specific type
+  int extracted = view.cast<int>();
+  EXPECT_EQ(extracted, 42);
+
+  // If we are not sure about the type
+  // we can use as to get an optional value
+  std::optional<int> maybe_int = view.as<int>();
+  if (maybe_int.has_value()) {
+    EXPECT_EQ(maybe_int.value(), 42);
+  }
+  // Try cast is another version that will try to run the type
+  // conversion even if the type does not exactly match
+  std::optional<int> maybe_int_try = view.try_cast<int>();
+  if (maybe_int_try.has_value()) {
+    EXPECT_EQ(maybe_int_try.value(), 42);
+  }
+}
+
+TEST(Example, Any) { ExampleAny(); }
+
+void ExampleFunctionFromPacked() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a typed lambda
+  ffi::Function fadd1 =
+      ffi::Function::FromPacked([](const ffi::AnyView* args, int32_t num_args, 
ffi::Any* rv) {
+        TVM_FFI_ICHECK_EQ(num_args, 1);
+        int a = args[0].cast<int>();
+        *rv = a + 1;
+      });
+  int b = fadd1(1).cast<int>();
+  EXPECT_EQ(b, 2);
+}
+
+void ExampleFunctionFromTyped() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a typed lambda
+  ffi::Function fadd1 = ffi::Function::FromTyped([](const int a) -> int { 
return a + 1; });
+  int b = fadd1(1).cast<int>();
+  EXPECT_EQ(b, 2);
+}
+
+void ExampleFunctionPassFunction() {
+  namespace ffi = tvm::ffi;
+  // Create a function from a typed lambda
+  ffi::Function fapply = ffi::Function::FromTyped(
+      [](const ffi::Function f, ffi::Any param) { return f(param.cast<int>()); 
});
+  ffi::Function fadd1 = ffi::Function::FromTyped(  //
+      [](const int a) -> int { return a + 1; });
+  int b = fapply(fadd1, 2).cast<int>();
+  EXPECT_EQ(b, 3);
+}
+
+void ExampleGlobalFunctionRegistry() {
+  namespace ffi = tvm::ffi;
+  ffi::reflection::GlobalDef().def("xyz.add1", [](const int a) -> int { return 
a + 1; });
+  ffi::Function fadd1 = ffi::Function::GetGlobalRequired("xyz.add1");
+  int b = fadd1(1).cast<int>();
+  EXPECT_EQ(b, 2);
+}
+
+void FuncThrowError() {
+  namespace ffi = tvm::ffi;
+  TVM_FFI_THROW(TypeError) << "test0";
+}
+
+void ExampleErrorHandling() {
+  namespace ffi = tvm::ffi;
+  try {
+    FuncThrowError();
+  } catch (const ffi::Error& e) {
+    EXPECT_EQ(e.kind(), "TypeError");
+    EXPECT_EQ(e.message(), "test0");
+    std::cout << e.traceback() << std::endl;
+  }
+}
+
+TEST(Example, Function) {
+  ExampleFunctionFromPacked();
+  ExampleFunctionFromTyped();
+  ExampleFunctionPassFunction();
+  ExampleGlobalFunctionRegistry();
+  ExampleErrorHandling();
+}
+
+struct CPUNDAlloc {
+  void AllocData(DLTensor* tensor) { tensor->data = 
malloc(tvm::ffi::GetDataSize(*tensor)); }
+  void FreeData(DLTensor* tensor) { free(tensor->data); }
+};
+
+void ExampleNDArray() {
+  namespace ffi = tvm::ffi;
+  ffi::Shape shape = {1, 2, 3};
+  DLDataType dtype = {kDLFloat, 32, 1};
+  DLDevice device = {kDLCPU, 0};
+  ffi::NDArray nd = ffi::NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, 
device);
+}
+
+void ExampleNDArrayDLPack() {
+  namespace ffi = tvm::ffi;
+  ffi::Shape shape = {1, 2, 3};
+  DLDataType dtype = {kDLFloat, 32, 1};
+  DLDevice device = {kDLCPU, 0};
+  ffi::NDArray nd = ffi::NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, 
device);
+  // convert to DLManagedTensorVersioned
+  DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned();
+  // load back from DLManagedTensorVersioned
+  ffi::NDArray nd2 = ffi::NDArray::FromDLPackVersioned(dlpack);
+}
+
+TEST(Example, NDArray) {
+  ExampleNDArray();
+  ExampleNDArrayDLPack();
+}
+
+void ExampleString() {
+  namespace ffi = tvm::ffi;
+  ffi::String str = "hello world";
+  EXPECT_EQ(str.size(), 11);
+  std::string std_str = str;
+  EXPECT_EQ(std_str, "hello world");
+}
+
+TEST(Example, String) { ExampleString(); }
+
+void ExampleArray() {
+  namespace ffi = tvm::ffi;
+  ffi::Array<int> numbers = {1, 2, 3};
+  EXPECT_EQ(numbers.size(), 3);
+  EXPECT_EQ(numbers[0], 1);
+
+  ffi::Function head = ffi::Function::FromTyped([](const ffi::Array<int> a) { 
return a[0]; });
+  EXPECT_EQ(head(numbers).cast<int>(), 1);
+
+  try {
+    // throw an error because 2.2 is not int
+    head(ffi::Array<ffi::Any>({1, 2.2}));
+  } catch (const ffi::Error& e) {
+    EXPECT_EQ(e.kind(), "TypeError");
+  }
+}
+
+void ExampleTuple() {
+  namespace ffi = tvm::ffi;
+  ffi::Tuple<int, ffi::String, bool> tup(42, "hello", true);
+
+  EXPECT_EQ(tup.get<0>(), 42);
+  EXPECT_EQ(tup.get<1>(), "hello");
+  EXPECT_EQ(tup.get<2>(), true);
+}
+
+TEST(Example, Array) {
+  ExampleArray();
+  ExampleTuple();
+}
+
+void ExampleMap() {
+  namespace ffi = tvm::ffi;
+
+  ffi::Map<ffi::String, int> map0 = {{"Alice", 100}, {"Bob", 95}};
+
+  EXPECT_EQ(map0.size(), 2);
+  EXPECT_EQ(map0.at("Alice"), 100);
+  EXPECT_EQ(map0.count("Alice"), 1);
+}
+
+TEST(Example, Map) { ExampleMap(); }
+
+void ExampleOptional() {
+  namespace ffi = tvm::ffi;
+  ffi::Optional<int> opt0 = 100;
+  EXPECT_EQ(opt0.has_value(), true);
+  EXPECT_EQ(opt0.value(), 100);
+
+  ffi::Optional<ffi::String> opt1;
+  EXPECT_EQ(opt1.has_value(), false);
+  EXPECT_EQ(opt1.value_or("default"), "default");
+}
+
+TEST(Example, Optional) { ExampleOptional(); }
+
+void ExampleVariant() {
+  namespace ffi = tvm::ffi;
+  ffi::Variant<int, ffi::String> var0 = 100;
+  EXPECT_EQ(var0.get<int>(), 100);
+
+  var0 = ffi::String("hello");
+  std::optional<ffi::String> maybe_str = var0.as<ffi::String>();
+  EXPECT_EQ(maybe_str.value(), "hello");
+
+  std::optional<int> maybe_int2 = var0.as<int>();
+  EXPECT_EQ(maybe_int2.has_value(), false);
+}
+
+TEST(Example, Variant) { ExampleVariant(); }
+
+// Step 1: Define the object class (stores the actual data)
+class MyIntPairObj : public tvm::ffi::Object {
+ public:
+  int64_t a;
+  int64_t b;
+
+  MyIntPairObj() = default;
+  MyIntPairObj(int64_t a, int64_t b) : a(a), b(b) {}
+
+  // Required: declare type information
+  static constexpr const char* _type_key = "example.MyIntPair";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MyIntPairObj, tvm::ffi::Object);
+};
+
+// Step 2: Define the reference wrapper (user-facing interface)
+class MyIntPair : public tvm::ffi::ObjectRef {
+ public:
+  // Constructor
+  explicit MyIntPair(int64_t a, int64_t b) { data_ = 
tvm::ffi::make_object<MyIntPairObj>(a, b); }
+
+  // Required: define object reference methods
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS(MyIntPair, tvm::ffi::ObjectRef, 
MyIntPairObj);
+};
+
+void ExampleObjectPtr() {
+  namespace ffi = tvm::ffi;
+  ffi::ObjectPtr<MyIntPairObj> obj = ffi::make_object<MyIntPairObj>(100, 200);
+  EXPECT_EQ(obj->a, 100);
+  EXPECT_EQ(obj->b, 200);
+}
+
+void ExampleObjectRef() {
+  namespace ffi = tvm::ffi;
+  MyIntPair pair(100, 200);
+  EXPECT_EQ(pair->a, 100);
+  EXPECT_EQ(pair->b, 200);
+}
+
+void ExampleObjectRefAny() {
+  namespace ffi = tvm::ffi;
+  MyIntPair pair(100, 200);
+  ffi::Any any = pair;
+  MyIntPair pair2 = any.cast<MyIntPair>();
+  EXPECT_EQ(pair2->a, 100);
+  EXPECT_EQ(pair2->b, 200);
+}
+
+TEST(Example, ObjectPtr) {
+  ExampleObjectPtr();
+  ExampleObjectRef();
+  ExampleObjectRefAny();
+}
+
+}  // namespace
