This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 3539d70  [FEAT] Recursive DLPack container conversion for auto 
torch.Tensor return (#517)
3539d70 is described below

commit 3539d7077620dbade194bd231091cd4029924b83
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Apr 2 00:18:14 2026 -0700

    [FEAT] Recursive DLPack container conversion for auto torch.Tensor return 
(#517)
    
    ### Problem
    
    When a packed FFI function receives `torch.Tensor` inputs, the return
    value is automatically converted back to `torch.Tensor` via DLPack — but
    only for bare ffi.Tensor returns. When the return is a container (Array,
    List, Map, Dict) containing tensors, the tensors inside remain as
    ffi.Tensor, requiring manual per-element conversion.
    
    ### Solution
    
    We perform a lazy conversion of each element in the container from
    ffi::Tensor to torch::Tensor when they're retrieved. A lazy conversion
    ensures both semantic correctness of containers and reduces runtime
    overhead compared with eager conversion.
    
    ### Stream Propagation
    | Container | Element | Stream set? | How |
    | :--- | :--- | :--- | :--- |
    | list/tuple | torch.Tensor | Yes | ConstructorCall → each element hits
    DLPackExchangeAPI_ setter |
    | list/tuple | ffi.Tensor | No | ConstructorCall → each element hits
    Tensor_ setter (no stream) |
    | dict | torch.Tensor (values) | Yes | Same as above via ConstructorCall
    |
    | dict | ffi.Tensor (values) | No | Same as above |
    | ffi.Array/ffi.List (tagged) | ffi.Tensor | Yes | ContainerObject_
    setter → _scan_seq_for_stream |
    | ffi.Map/ffi.Dict (tagged) | ffi.Tensor | Yes | ContainerObject_ setter
    → _scan_map_for_stream |
    | ffi.Array/ffi.List (untagged) | ffi.Tensor | No | Object_ setter
    (pass-through) |
    | ffi.Map/ffi.Dict (untagged) | ffi.Tensor | No | Object_ setter
    (pass-through) |
---
 python/tvm_ffi/__init__.py                       |   2 +-
 python/tvm_ffi/container.py                      |   8 +-
 python/tvm_ffi/core.pyi                          |   1 +
 python/tvm_ffi/cython/function.pxi               | 208 +++++++++++++++++++++-
 python/tvm_ffi/cython/object.pxi                 |  28 +++
 src/ffi/testing/testing.cc                       |  48 +++++-
 tests/python/test_container.py                   |   8 +-
 tests/python/test_container_dlpack_conversion.py | 210 +++++++++++++++++++++++
 tests/python/test_current_work_stream_gpu.py     |   2 +-
 9 files changed, 503 insertions(+), 12 deletions(-)

diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index aa17624..77e95e5 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -66,7 +66,7 @@ if TYPE_CHECKING or not _is_config_mode():
         init_ffi_api,
     )
     from ._dtype import dtype
-    from .core import Object, ObjectConvertible, Function, CAny
+    from .core import Object, ObjectConvertible, Function, CAny, CContainerBase
     from ._convert import convert
     from .error import register_error
     from ._tensor import Device, device, DLDeviceType
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index de20155..332f817 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -134,7 +134,7 @@ def normalize_index(length: int, idx: SupportsIndex) -> int:
 
 
 @register_object("ffi.Array")
-class Array(core.Object, Sequence[T]):
+class Array(core.CContainerBase, core.Object, Sequence[T]):
     """Array container that represents a sequence of values in the FFI.
 
     :py:func:`tvm_ffi.convert` will map python list/tuple to this class.
@@ -212,7 +212,7 @@ class Array(core.Object, Sequence[T]):
 
 
 @register_object("ffi.List")
-class List(core.Object, MutableSequence[T]):
+class List(core.CContainerBase, core.Object, MutableSequence[T]):
     """Mutable list container that represents a mutable sequence in the FFI."""
 
     # tvm-ffi-stubgen(begin): object/ffi.List
@@ -438,7 +438,7 @@ class ItemsView(ItemsViewBase[K, V]):
 
 
 @register_object("ffi.Map")
-class Map(core.Object, Mapping[K, V]):
+class Map(core.CContainerBase, core.Object, Mapping[K, V]):
     """Map container.
 
     :py:func:`tvm_ffi.convert` will map python dict to this class.
@@ -544,7 +544,7 @@ class Map(core.Object, Mapping[K, V]):
 
 
 @register_object("ffi.Dict")
-class Dict(core.Object, MutableMapping[K, V]):
+class Dict(core.CContainerBase, core.Object, MutableMapping[K, V]):
     """Mutable dictionary container with shared reference semantics.
 
     Unlike :class:`Map`, ``Dict`` does NOT implement copy-on-write.
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 88de53a..f6e8aa0 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -50,6 +50,7 @@ class CObject:
     def _move(self) -> ObjectRValueRef: ...
     def __move_handle_from__(self, other: CObject) -> None: ...
 
+class CContainerBase(CObject): ...
 class Object(CObject): ...
 
 def object_repr(obj: CObject) -> str: ...
diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index f66d7df..65ba378 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -62,6 +62,178 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
     return bytearray_to_bytes(&bytes)
 
 
+cdef inline bint _check_elem_for_stream(
+    TVMFFIAny* elem_result,
+    const DLPackExchangeAPI* api,
+    TVMFFIPyCallContext* ctx
+) noexcept:
+    """Check a single element for non-CPU tensor; set stream if found.
+
+    Returns True if a non-CPU tensor was found and stream was set.
+    Releases the element ref (for object types) in all cases.
+    """
+    cdef DLTensor* dltensor
+    cdef void* stream = NULL
+    cdef int32_t ti = elem_result.type_index
+
+    if ti == kTVMFFITensor:
+        dltensor = 
TVMFFITensorGetDLTensorPtr(<TVMFFIObjectHandle>elem_result.v_obj)
+        if dltensor.device.device_type != kDLCPU:
+            ctx.device_type = dltensor.device.device_type
+            ctx.device_id = dltensor.device.device_id
+            api.current_work_stream(
+                dltensor.device.device_type,
+                dltensor.device.device_id,
+                &stream)
+            ctx.stream = <TVMFFIStreamHandle>stream
+            TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
+            return True
+        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
+    elif (ti == kTVMFFIArray or ti == kTVMFFIList
+          or ti == kTVMFFIMap or ti == kTVMFFIDict):
+        _scan_container_for_stream(
+            <TVMFFIObjectHandle>elem_result.v_obj, ti, api, ctx)
+        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
+        if ctx.device_type != -1:
+            return True
+    elif ti >= kTVMFFIStaticObjectBegin:
+        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
+    return False
+
+
+cdef inline void _scan_seq_for_stream(
+    TVMFFIObjectHandle chandle,
+    int32_t type_index,
+    const DLPackExchangeAPI* api,
+    TVMFFIPyCallContext* ctx
+) noexcept:
+    """Scan an Array or List for the first non-CPU tensor."""
+    cdef TVMFFIObjectHandle size_func_handle
+    cdef TVMFFIObjectHandle getitem_func_handle
+    cdef TVMFFIAny size_args[1]
+    cdef TVMFFIAny size_result
+    cdef TVMFFIAny getitem_args[2]
+    cdef TVMFFIAny elem_result
+    cdef int64_t n, i
+
+    if type_index == kTVMFFIArray:
+        size_func_handle = (<CObject>_FFI_ARRAY_SIZE).chandle
+        getitem_func_handle = (<CObject>_FFI_ARRAY_GET_ITEM).chandle
+    else:
+        size_func_handle = (<CObject>_FFI_LIST_SIZE).chandle
+        getitem_func_handle = (<CObject>_FFI_LIST_GET_ITEM).chandle
+
+    size_args[0].type_index = type_index
+    size_args[0].v_obj = <TVMFFIObject*>chandle
+    size_result.type_index = kTVMFFINone
+    size_result.v_int64 = 0
+    if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
+        return
+
+    n = size_result.v_int64
+    if n == 0:
+        return
+
+    getitem_args[0].type_index = type_index
+    getitem_args[0].v_obj = <TVMFFIObject*>chandle
+
+    for i in range(n):
+        getitem_args[1].type_index = kTVMFFIInt
+        getitem_args[1].v_int64 = i
+        elem_result.type_index = kTVMFFINone
+        elem_result.v_int64 = 0
+        if TVMFFIFunctionCall(getitem_func_handle, getitem_args, 2, 
&elem_result) != 0:
+            return
+        if _check_elem_for_stream(&elem_result, api, ctx):
+            return
+
+
+cdef inline void _scan_map_for_stream(
+    TVMFFIObjectHandle chandle,
+    int32_t type_index,
+    const DLPackExchangeAPI* api,
+    TVMFFIPyCallContext* ctx
+) noexcept:
+    """Scan a Map or Dict's values for the first non-CPU tensor."""
+    cdef TVMFFIObjectHandle size_func_handle
+    cdef TVMFFIObjectHandle iter_func_handle
+    cdef TVMFFIAny size_args[1]
+    cdef TVMFFIAny size_result
+    cdef TVMFFIAny iter_args[1]
+    cdef TVMFFIAny iter_result
+    cdef TVMFFIObjectHandle iter_handle = NULL
+    cdef TVMFFIAny cmd[1]
+    cdef TVMFFIAny val_result
+    cdef TVMFFIAny advance_result
+    cdef int64_t n, i
+
+    if type_index == kTVMFFIMap:
+        size_func_handle = (<CObject>_FFI_MAP_SIZE).chandle
+        iter_func_handle = (<CObject>_FFI_MAP_FORWARD_ITER).chandle
+    else:
+        size_func_handle = (<CObject>_FFI_DICT_SIZE).chandle
+        iter_func_handle = (<CObject>_FFI_DICT_FORWARD_ITER).chandle
+
+    size_args[0].type_index = type_index
+    size_args[0].v_obj = <TVMFFIObject*>chandle
+    size_result.type_index = kTVMFFINone
+    size_result.v_int64 = 0
+    if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
+        return
+
+    n = size_result.v_int64
+    if n == 0:
+        return
+
+    # Get forward iterator
+    iter_args[0].type_index = type_index
+    iter_args[0].v_obj = <TVMFFIObject*>chandle
+    iter_result.type_index = kTVMFFINone
+    iter_result.v_int64 = 0
+    if TVMFFIFunctionCall(iter_func_handle, iter_args, 1, &iter_result) != 0:
+        return
+    iter_handle = <TVMFFIObjectHandle>iter_result.v_obj
+
+    for i in range(n):
+        # Get value (command=1)
+        cmd[0].type_index = kTVMFFIInt
+        cmd[0].v_int64 = 1
+        val_result.type_index = kTVMFFINone
+        val_result.v_int64 = 0
+        if TVMFFIFunctionCall(iter_handle, cmd, 1, &val_result) != 0:
+            TVMFFIObjectDecRef(iter_handle)
+            return
+        if _check_elem_for_stream(&val_result, api, ctx):
+            TVMFFIObjectDecRef(iter_handle)
+            return
+        # Advance (command=2), skip after last entry
+        if i < n - 1:
+            cmd[0].v_int64 = 2
+            advance_result.type_index = kTVMFFINone
+            advance_result.v_int64 = 0
+            if TVMFFIFunctionCall(iter_handle, cmd, 1, &advance_result) != 0:
+                TVMFFIObjectDecRef(iter_handle)
+                return
+
+    TVMFFIObjectDecRef(iter_handle)
+
+
+cdef inline void _scan_container_for_stream(
+    TVMFFIObjectHandle chandle,
+    int32_t type_index,
+    const DLPackExchangeAPI* api,
+    TVMFFIPyCallContext* ctx
+) noexcept:
+    """Scan a container for the first non-CPU tensor to set stream context.
+
+    Best-effort: silently returns on any FFI error (equivalent to no stream 
set).
+    """
+    if type_index == kTVMFFIArray or type_index == kTVMFFIList:
+        _scan_seq_for_stream(chandle, type_index, api, ctx)
+    elif type_index == kTVMFFIMap or type_index == kTVMFFIDict:
+        _scan_map_for_stream(chandle, type_index, api, ctx)
+
+
 cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI* 
c_ctx_dlpack_api = NULL):
     """convert result to return value."""
     cdef int32_t type_index
@@ -72,7 +244,10 @@ cdef inline object make_ret(TVMFFIAny result, const 
DLPackExchangeAPI* c_ctx_dlp
     elif type_index == kTVMFFIOpaquePyObject:
         return make_ret_opaque_object(result)
     elif type_index >= kTVMFFIStaticObjectBegin:
-        return make_ret_object(result)
+        obj = make_ret_object(result)
+        if c_ctx_dlpack_api != NULL and isinstance(obj, CContainerBase):
+            (<CContainerBase>obj)._dlpack_exchange_api = c_ctx_dlpack_api
+        return obj
     # the following code should be optimized to switch case
     if type_index == kTVMFFINone:
         return None
@@ -149,6 +324,26 @@ cdef int TVMFFIPyArgSetterObject_(
     return 0
 
 
+cdef int TVMFFIPyArgSetterContainerObject_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* arg, TVMFFIAny* out
+) except -1:
+    """Setter for container objects (Array, List, Map, Dict).
+
+    Propagates DLPack exchange API tag and scans for stream context.
+    """
+    out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+    out.v_ptr = (<CObject>arg).chandle
+    cdef const DLPackExchangeAPI* api = 
(<CContainerBase>arg)._dlpack_exchange_api
+    if api != NULL:
+        if ctx.dlpack_c_exchange_api == NULL:
+            ctx.dlpack_c_exchange_api = api
+        if ctx.device_type == -1 and api.current_work_stream != NULL:
+            _scan_container_for_stream(
+                (<CObject>arg).chandle, out.type_index, api, ctx)
+    return 0
+
+
 cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
     PyObject* arg, TVMFFIAny* out
@@ -727,6 +922,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, 
TVMFFIPyArgSetter* out) exce
     if isinstance(arg, Tensor):
         out.func = TVMFFIPyArgSetterTensor_
         return 0
+    if isinstance(arg, CContainerBase):
+        out.func = TVMFFIPyArgSetterContainerObject_
+        return 0
     if isinstance(arg, CObject):
         out.func = TVMFFIPyArgSetterObject_
         return 0
@@ -1147,3 +1345,11 @@ cdef Function _OBJECT_FROM_JSON_GRAPH_STR = 
_get_global_func("ffi.FromJSONGraphS
 cdef Function _OBJECT_TO_JSON_GRAPH_STR = 
_get_global_func("ffi.ToJSONGraphString", True)
 cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True)
 cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True)
+cdef Function _FFI_ARRAY_GET_ITEM = _get_global_func("ffi.ArrayGetItem", True)
+cdef Function _FFI_ARRAY_SIZE = _get_global_func("ffi.ArraySize", True)
+cdef Function _FFI_LIST_GET_ITEM = _get_global_func("ffi.ListGetItem", True)
+cdef Function _FFI_LIST_SIZE = _get_global_func("ffi.ListSize", True)
+cdef Function _FFI_MAP_SIZE = _get_global_func("ffi.MapSize", True)
+cdef Function _FFI_MAP_FORWARD_ITER = 
_get_global_func("ffi.MapForwardIterFunctor", True)
+cdef Function _FFI_DICT_SIZE = _get_global_func("ffi.DictSize", True)
+cdef Function _FFI_DICT_FORWARD_ITER = 
_get_global_func("ffi.DictForwardIterFunctor", True)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 0f4fdab..ed61917 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -166,6 +166,34 @@ cdef class CObject:
         self.chandle = chandle
 
 
+cdef class CContainerBase(CObject):
+    """Cython base for container types that support lazy DLPack conversion.
+
+    Stores a ``DLPackExchangeAPI*`` tag so that element access on a
+    returned container can automatically convert ``ffi.Tensor`` to
+    the framework tensor type (e.g. ``torch.Tensor``).
+    """
+    # Raw pointer to the DLPack exchange API struct.  Not ref-counted.
+    #
+    # Lifetime safety: the two sources of this pointer are both
+    # effectively process-lifetime:
+    #
+    # 1. __dlpack_c_exchange_api__ (e.g. torch.Tensor) — points to a
+    #    static struct in the framework's C++ runtime.  The source
+    #    type is kept alive by _DISPATCH_TYPE_KEEP_ALIVE (set in
+    #    TVMFFIPyArgSetterFactory_), which prevents module unloading.
+    #
+    # 2. GetTorchFallbackExchangeAPI() — returns the address of a
+    #    module-level Cython static; lives for the entire process.
+    #
+    # The DLPack spec also mandates that DLPackExchangeAPI* must stay
+    # alive throughout the lifetime of the process (dlpack.h line 600).
+    cdef const DLPackExchangeAPI* _dlpack_exchange_api
+
+    def __cinit__(self):
+        self._dlpack_exchange_api = NULL
+
+
 class _ObjectSlotsMeta(ABCMeta):
     def __new__(mcls, name: str, bases: tuple[type, ...], ns: dict[str, Any], 
**kwargs: Any):
         if "__slots__" not in ns:
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index ad31b75..2751f83 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -542,7 +542,53 @@ TVM_FFI_STATIC_INIT_BLOCK() {
            })
       .def("testing.optional_tensor_view_has_value",
            [](const Optional<TensorView>& t) { return t.has_value(); })
-      .def_method("testing.TestIntPairSum", &TestIntPair::Sum, "Get sum of the 
pair");
+      .def_method("testing.TestIntPairSum", &TestIntPair::Sum, "Get sum of the 
pair")
+      // Container-with-tensor test helpers for DLPack container conversion
+      // NOLINTBEGIN(performance-unnecessary-value-param)
+      .def("testing.make_array_with_tensor", [](Tensor t) -> Array<Any> { 
return {std::move(t)}; })
+      .def("testing.make_array_with_mixed",
+           [](Tensor t, int64_t x) -> Array<Any> { return {std::move(t), x, 
String("hello")}; })
+      .def("testing.make_nested_array_with_tensor",
+           [](const Tensor& t) -> Array<Any> {
+             Array<Any> inner{t, 42};
+             return {std::move(inner), t};
+           })
+      .def("testing.make_list_with_tensor",
+           [](Tensor t, int64_t x) -> List<Any> {
+             List<Any> result;
+             result.push_back(std::move(t));
+             result.push_back(x);
+             return result;
+           })
+      .def("testing.make_map_with_tensor",
+           [](Tensor t) -> Map<String, Any> {
+             Map<String, Any> result;
+             result.Set("tensor", std::move(t));
+             result.Set("value", 42);
+             result.Set("name", String("test"));
+             return result;
+           })
+      .def("testing.make_dict_with_tensor",
+           [](Tensor t) -> Dict<String, Any> {
+             Dict<String, Any> result;
+             result.Set("tensor", std::move(t));
+             result.Set("value", 42);
+             return result;
+           })
+      .def("testing.make_empty_array_with_tensor_input",
+           [](const Tensor& t) -> Array<Any> { return Array<Any>(); })
+      .def("testing.make_nested_map_with_tensor",
+           [](const Tensor& t1, Tensor t2) -> Map<String, Any> {
+             Array<Any> arr{t1, std::move(t2)};
+             Map<String, Any> inner;
+             inner.Set("t", t1);
+             Map<String, Any> result;
+             result.Set("array", std::move(arr));
+             result.Set("map", std::move(inner));
+             result.Set("scalar", 99);
+             return result;
+           });
+  // NOLINTEND(performance-unnecessary-value-param)
 }
 
 }  // namespace ffi
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index a4da236..b4e85c8 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -493,14 +493,14 @@ def test_seq_cross_conv_incompatible_list_to_array() -> 
None:
     """List with incompatible element types should fail when cast to 
Array<int>."""
     lst = tvm_ffi.List(["not", "ints"])
     with pytest.raises(TypeError):
-        testing.schema_id_arr_int(lst)  # type: ignore[arg-type]
+        testing.schema_id_arr_int(lst)  # type: ignore[invalid-argument-type]
 
 
 def test_seq_cross_conv_incompatible_array_to_list() -> None:
     """Array with incompatible element types should fail when cast to 
List<int>."""
     arr = tvm_ffi.Array(["not", "ints"])
     with pytest.raises(TypeError):
-        testing.schema_id_list_int(arr)  # type: ignore[arg-type]
+        testing.schema_id_list_int(arr)  # type: ignore[invalid-argument-type]
 
 
 def test_missing_object() -> None:
@@ -725,11 +725,11 @@ def test_map_cross_conv_incompatible_dict_to_map() -> 
None:
     """Dict with incompatible value types should fail when cast to Map<String, 
int>."""
     d = tvm_ffi.Dict({"a": "not_int", "b": "still_not_int"})
     with pytest.raises(TypeError):
-        testing.schema_id_map_str_int(d)  # type: ignore[arg-type]
+        testing.schema_id_map_str_int(d)  # type: ignore[invalid-argument-type]
 
 
 def test_map_cross_conv_incompatible_map_to_dict() -> None:
     """Map with incompatible value types should fail when cast to Dict<String, 
int>."""
     m = tvm_ffi.Map({"a": "not_int", "b": "still_not_int"})
     with pytest.raises(TypeError):
-        testing.schema_id_dict_str_int(m)  # type: ignore[arg-type]
+        testing.schema_id_dict_str_int(m)  # type: 
ignore[invalid-argument-type]
diff --git a/tests/python/test_container_dlpack_conversion.py 
b/tests/python/test_container_dlpack_conversion.py
new file mode 100644
index 0000000..454ecca
--- /dev/null
+++ b/tests/python/test_container_dlpack_conversion.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for lazy container DLPack conversion when DLPack exchange API is 
active."""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+try:
+    import torch
+    import torch.version
+except ImportError:
+    torch = None  # ty: ignore[invalid-assignment]
+
+import tvm_ffi
+
+pytestmark = pytest.mark.skipif(torch is None, reason="torch is not installed")
+
+
+def test_array_tensor_only() -> None:
+    """Array<Tensor> stays as Array; element access converts to 
torch.Tensor."""
+    assert torch is not None
+    x = torch.arange(8, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_array_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Array)
+    assert len(result) == 1
+    elem = result[0]
+    assert isinstance(elem, torch.Tensor)
+    assert elem.data_ptr() == x.data_ptr()
+
+
+def test_array_mixed() -> None:
+    """Array with Tensor + int + string: lazy conversion on access."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_array_with_mixed")
+    result = f(x, 42)
+    assert isinstance(result, tvm_ffi.Array)
+    assert len(result) == 3
+    assert isinstance(result[0], torch.Tensor)
+    assert result[0].data_ptr() == x.data_ptr()
+    assert result[1] == 42
+    assert result[2] == "hello"
+
+
+def test_array_nested() -> None:
+    """Nested Array<Array<Tensor>>: inner arrays also get tagged."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_nested_array_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Array)
+    assert len(result) == 2
+    # First element is inner array
+    inner = result[0]
+    assert isinstance(inner, tvm_ffi.Array)
+    assert len(inner) == 2
+    assert isinstance(inner[0], torch.Tensor)
+    assert inner[0].data_ptr() == x.data_ptr()
+    assert inner[1] == 42
+    # Second element is a tensor
+    assert isinstance(result[1], torch.Tensor)
+    assert result[1].data_ptr() == x.data_ptr()
+
+
+def test_list_with_tensor() -> None:
+    """List<Any> with tensor: stays as List, elements convert on access."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_list_with_tensor")
+    result = f(x, 7)
+    assert isinstance(result, tvm_ffi.List)
+    assert len(result) == 2
+    assert isinstance(result[0], torch.Tensor)
+    assert result[0].data_ptr() == x.data_ptr()
+    assert result[1] == 7
+
+
+def test_map_with_tensor() -> None:
+    """Map<String, Any> with tensor value: stays as Map, values convert on 
access."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_map_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Map)
+    assert len(result) == 3
+    assert isinstance(result["tensor"], torch.Tensor)
+    assert result["tensor"].data_ptr() == x.data_ptr()
+    assert result["value"] == 42
+    assert result["name"] == "test"
+
+
+def test_dict_with_tensor() -> None:
+    """Dict<String, Any> with tensor value: stays as Dict, values convert on 
access."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_dict_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Dict)
+    assert len(result) == 2
+    assert isinstance(result["tensor"], torch.Tensor)
+    assert result["tensor"].data_ptr() == x.data_ptr()
+    assert result["value"] == 42
+
+
+def test_nested_map_with_array() -> None:
+    """Nested Map with Array values: all containers tagged, lazy conversion on 
access."""
+    assert torch is not None
+    x1 = torch.arange(4, dtype=torch.float32)
+    x2 = torch.arange(8, dtype=torch.int32)
+    f = tvm_ffi.get_global_func("testing.make_nested_map_with_tensor")
+    result = f(x1, x2)
+    assert isinstance(result, tvm_ffi.Map)
+    # "array" -> Array with tagged tensors
+    arr = result["array"]
+    assert isinstance(arr, tvm_ffi.Array)
+    assert len(arr) == 2
+    assert isinstance(arr[0], torch.Tensor)
+    assert isinstance(arr[1], torch.Tensor)
+    # "map" -> nested Map
+    inner_map = result["map"]
+    assert isinstance(inner_map, tvm_ffi.Map)
+    assert isinstance(inner_map["t"], torch.Tensor)
+    # "scalar" -> int
+    assert result["scalar"] == 99
+
+
+def test_empty_array() -> None:
+    """Empty Array with torch input: stays as empty Array."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_empty_array_with_tensor_input")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Array)
+    assert len(result) == 0
+
+
+def test_no_torch_input_no_conversion() -> None:
+    """Without torch tensor input, containers stay as FFI types with no tag."""
+    x = tvm_ffi.from_dlpack(np.arange(4, dtype="float32"))
+    f = tvm_ffi.get_global_func("testing.make_array_with_tensor")
+    result = f(x)
+    # No torch input, so no dlpack API set -> normal FFI Array return
+    assert isinstance(result, tvm_ffi.Array)
+    assert isinstance(result[0], tvm_ffi.Tensor)
+
+
+def test_data_correctness() -> None:
+    """Verify tensor data is correct after lazy container conversion."""
+    assert torch is not None
+    x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_array_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Array)
+    elem = result[0]
+    assert isinstance(elem, torch.Tensor)
+    np.testing.assert_equal(elem.numpy(), x.numpy())
+
+
+def test_echo_bare_tensor_unchanged() -> None:
+    """Existing behavior: bare tensor return still works."""
+    assert torch is not None
+    x = torch.arange(128)
+    fecho = tvm_ffi.get_global_func("testing.echo")
+    y = fecho(x)
+    assert isinstance(y, torch.Tensor)
+    assert y.data_ptr() == x.data_ptr()
+
+
+def test_container_preserves_identity() -> None:
+    """Lazy conversion preserves container identity (can be passed back to 
FFI)."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_array_with_tensor")
+    result = f(x)
+    assert isinstance(result, tvm_ffi.Array)
+    # Pass container back to FFI (echo)
+    fecho = tvm_ffi.get_global_func("testing.echo")
+    echoed = fecho(result)
+    assert isinstance(echoed, tvm_ffi.Array)
+    assert isinstance(echoed[0], torch.Tensor)
+    assert echoed[0].data_ptr() == x.data_ptr()
+
+
+def test_mutable_list_shared_semantics() -> None:
+    """Lazy conversion preserves mutable list shared-reference semantics."""
+    assert torch is not None
+    x = torch.arange(4, dtype=torch.float32)
+    f = tvm_ffi.get_global_func("testing.make_list_with_tensor")
+    result = f(x, 7)
+    assert isinstance(result, tvm_ffi.List)
+    # The result is the actual FFI List, not a detached copy
+    assert result.same_as(result)
diff --git a/tests/python/test_current_work_stream_gpu.py 
b/tests/python/test_current_work_stream_gpu.py
index fcd3b94..910962e 100644
--- a/tests/python/test_current_work_stream_gpu.py
+++ b/tests/python/test_current_work_stream_gpu.py
@@ -44,7 +44,7 @@ else:
 @pytest.mark.skipif(not _HAS_DLPACK_EXCHANGE_API, reason="Requires 
__dlpack_c_exchange_api__")
 def test_current_work_stream_matches_torch_stream() -> None:
     assert torch is not None
-    api_attr = torch.Tensor.__dlpack_c_exchange_api__
+    api_attr = torch.Tensor.__dlpack_c_exchange_api__  # ty: 
ignore[unresolved-attribute]
 
     pythonapi = ctypes.pythonapi
     pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t

Reply via email to