This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 678984f62d [FFI][ABI] Better String and Nested Container handling
(#18311)
678984f62d is described below
commit 678984f62d6aadf87ee75148ee0cf87a72131ef9
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Sep 13 13:52:29 2025 -0400
[FFI][ABI] Better String and Nested Container handling (#18311)
[FFI][ABI][REFACTOR] Better String and nested container handling
This PR improves the overall String/Bytes and nested container handling.
It also fixes a bug in temporary FFI object recycling: temporary objects are
now released via TVMFFIObjectDecRef instead of invoking the deleter directly.
- Introduce formal API for string/bytes creation
- Updates the tuple/dict conversion to also preserve the torch stream
- So if a function takes a list of torch.Tensor, the torch stream will be
set up in the call context
- Optimizes recursive argument conversion by moving most logic into C++
---
ffi/include/tvm/ffi/c_api.h | 19 ++
ffi/pyproject.toml | 2 +-
ffi/python/tvm_ffi/_convert.py | 11 +-
ffi/python/tvm_ffi/_optional_torch_c_dlpack.py | 1 -
ffi/python/tvm_ffi/cython/base.pxi | 11 +
ffi/python/tvm_ffi/cython/function.pxi | 267 ++++++++++++++++-----
ffi/python/tvm_ffi/cython/object.pxi | 11 +-
ffi/python/tvm_ffi/cython/string.pxi | 5 -
ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 101 +++++++-
ffi/src/ffi/object.cc | 18 ++
ffi/tests/python/test_function.py | 33 +++
ffi/tests/python/test_load_inline.py | 13 +-
src/runtime/disco/protocol.h | 6 +-
src/runtime/minrpc/rpc_reference.h | 4 +-
src/runtime/rpc/rpc_endpoint.cc | 44 +++-
15 files changed, 432 insertions(+), 114 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index a53dac4d00..f13f820b7f 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -555,6 +555,25 @@ TVM_FFI_DLL int
TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* from,
*/
TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from,
DLManagedTensorVersioned** out);
+//---------------------------------------------------------------
+// Section: string/bytes support APIs.
+// These APIs are used to simplify the string/bytes construction
+//---------------------------------------------------------------
+/*!
+ * \brief Reinterpret the content of TVMFFIByteArray to String.
+ * \param input The TVMFFIByteArray to convert.
+ * \param out The output String owned by the caller, maybe a SmallStr or a Str
object.
+ * \return 0 on success, nonzero on failure.
+ */
+TVM_FFI_DLL int TVMFFIStringFromByteArray(const TVMFFIByteArray* input,
TVMFFIAny* out);
+
+/*!
+ * \brief Reinterpret the content of TVMFFIByteArray to Bytes.
+ * \param input The TVMFFIByteArray to convert.
+ * \param out The output Bytes owned by the caller, maybe a SmallBytes or a
Bytes object.
+ * \return 0 on success, nonzero on failure.
+ */
+TVM_FFI_DLL int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input,
TVMFFIAny* out);
//---------------------------------------------------------------
// Section: dtype string support APIs.
diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml
index 8c146f41c4..cc2df03f0a 100644
--- a/ffi/pyproject.toml
+++ b/ffi/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0a12"
+version = "0.1.0a13"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/ffi/python/tvm_ffi/_convert.py b/ffi/python/tvm_ffi/_convert.py
index b1b972633d..a0b6c1b117 100644
--- a/ffi/python/tvm_ffi/_convert.py
+++ b/ffi/python/tvm_ffi/_convert.py
@@ -40,13 +40,9 @@ def convert(value: Any) -> Any:
automatically converted. So this function is mainly
only used in internal or testing scenarios.
"""
- if isinstance(value, core.Object):
+ if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)):
return value
- elif isinstance(value, core.PyNativeObject):
- return value
- elif isinstance(value, (bool, Number)):
- return value
- elif isinstance(value, (list, tuple)):
+ elif isinstance(value, (tuple, list)):
return container.Array(value)
elif isinstance(value, dict):
return container.Map(value)
@@ -67,6 +63,3 @@ def convert(value: Any) -> Any:
else:
# in this case, it is an opaque python object
return core._convert_to_opaque_object(value)
-
-
-core._set_func_convert_to_object(convert)
diff --git a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py
b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py
index fc5851af17..f44855247a 100644
--- a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -384,7 +384,6 @@ int64_t TorchDLPackTensorAllocatorPtr() {
],
extra_cflags=["-O3"],
extra_include_paths=libinfo.include_paths() +
cpp_extension.include_paths("cuda"),
- verbose=True,
)
# set the dlpack related flags
torch.Tensor.__c_dlpack_from_pyobject__ =
mod.TorchDLPackFromPyObjectPtr()
diff --git a/ffi/python/tvm_ffi/cython/base.pxi
b/ffi/python/tvm_ffi/cython/base.pxi
index fdb06f5105..ef583c7529 100644
--- a/ffi/python/tvm_ffi/cython/base.pxi
+++ b/ffi/python/tvm_ffi/cython/base.pxi
@@ -212,6 +212,8 @@ cdef extern from "tvm/ffi/c_api.h":
TVMFFIByteArray* traceback) nogil
int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex)
nogil
+ int TVMFFIStringFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out)
nogil
+ int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil
const TVMFFIByteArray* TVMFFITraceback(
@@ -284,6 +286,15 @@ cdef extern from "tvm_ffi_python_helpers.h":
DLPackToPyObject* out_dlpack_importer
) except -1
+ int TVMFFIPyConstructorCall(
+ TVMFFIPyArgSetterFactory setter_factory,
+ void* chandle,
+ PyObject* py_arg_tuple,
+ TVMFFIAny* result,
+ int* c_api_ret_code,
+ TVMFFIPyCallContext* parent_ctx
+ ) except -1
+
int TVMFFIPyCallFieldSetter(
TVMFFIPyArgSetterFactory setter_factory,
TVMFFIFieldSetter field_setter,
diff --git a/ffi/python/tvm_ffi/cython/function.pxi
b/ffi/python/tvm_ffi/cython/function.pxi
index 9b86054b71..71c9522ddb 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -88,6 +88,27 @@ cdef inline object make_ret(TVMFFIAny result,
DLPackToPyObject c_dlpack_to_pyobj
raise ValueError("Unhandled type index %d" % type_index)
+##----------------------------------------------------------------------------
+## Helper to simplify calling constructor
+##----------------------------------------------------------------------------
+cdef inline int ConstructorCall(void* constructor_handle,
+ PyObject* py_arg_tuple,
+ void** handle,
+ TVMFFIPyCallContext* parent_ctx) except -1:
+ """Call contructor of a handle function"""
+ cdef TVMFFIAny result
+ cdef int c_api_ret_code
+ # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone
+ result.type_index = kTVMFFINone
+ result.v_int64 = 0
+ TVMFFIPyConstructorCall(
+ TVMFFIPyArgSetterFactory_, constructor_handle, py_arg_tuple, &result,
&c_api_ret_code,
+ parent_ctx
+ )
+ CHECK_CALL(c_api_ret_code)
+ handle[0] = result.v_ptr
+ return 0
+
##----------------------------------------------------------------------------
## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_
##----------------------------------------------------------------------------
@@ -244,18 +265,33 @@ cdef int TVMFFIPyArgSetterStr_(
) except -1:
"""Setter for str"""
cdef object arg = <object>py_arg
+ cdef bytes tstr = arg.encode("utf-8")
+ cdef char* data
+ cdef Py_ssize_t size
+ cdef TVMFFIByteArray cdata
+
+ PyBytes_AsStringAndSize(tstr, &data, &size)
+ cdata.data = data
+ cdata.size = size
+ CHECK_CALL(TVMFFIStringFromByteArray(&cdata, out))
+ if out.type_index >= kTVMFFIStaticObjectBegin:
+ TVMFFIPyPushTempFFIObject(ctx, out.v_ptr)
+ return 0
+
- if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
+cdef int TVMFFIPyArgSetterPyNativeObjectStr_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Specially handle String as its __tvm_ffi_object__ may be empty"""
+ cdef object arg = <object>py_arg
+ # need to check if the arg is a large string returned from ffi
+ if arg.__tvm_ffi_object__ is not None:
arg = arg.__tvm_ffi_object__
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
out.v_ptr = (<Object>arg).chandle
return 0
-
- tstr = c_str(arg)
- out.type_index = kTVMFFIRawStr
- out.v_c_str = tstr
- TVMFFIPyPushTempPyObject(ctx, <PyObject*>tstr)
- return 0
+ return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out)
cdef int TVMFFIPyArgSetterBytes_(
@@ -265,17 +301,50 @@ cdef int TVMFFIPyArgSetterBytes_(
"""Setter for bytes"""
cdef object arg = <object>py_arg
- if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
+ if isinstance(arg, bytearray):
+ arg = bytes(arg)
+
+ cdef char* data
+ cdef Py_ssize_t size
+ cdef TVMFFIByteArray cdata
+
+ PyBytes_AsStringAndSize(arg, &data, &size)
+ cdata.data = data
+ cdata.size = size
+ CHECK_CALL(TVMFFIBytesFromByteArray(&cdata, out))
+
+ if out.type_index >= kTVMFFIStaticObjectBegin:
+ TVMFFIPyPushTempFFIObject(ctx, out.v_ptr)
+ return 0
+
+
+cdef int TVMFFIPyArgSetterPyNativeObjectBytes_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Specially handle Bytes as its __tvm_ffi_object__ may be empty"""
+ cdef object arg = <object>py_arg
+ # need to check if the arg is a large bytes returned from ffi
+ if arg.__tvm_ffi_object__ is not None:
arg = arg.__tvm_ffi_object__
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
out.v_ptr = (<Object>arg).chandle
return 0
+ return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out)
- arg = ByteArrayArg(arg)
- out.type_index = kTVMFFIByteArrayPtr
- out.v_int64 = 0
- out.v_ptr = (<ByteArrayArg>arg).cptr()
- TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
+
+cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Specially handle Bytes as its __tvm_ffi_object__ may be empty"""
+ cdef object arg = <object>py_arg
+ if arg.__tvm_ffi_object__ is None:
+ raise ValueError(f"__tvm_ffi_object__ is None for {type(arg)}")
+ assert arg.__tvm_ffi_object__ is not None
+ arg = arg.__tvm_ffi_object__
+ out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
+ out.v_ptr = (<Object>arg).chandle
return 0
@@ -306,10 +375,11 @@ cdef int TVMFFIPyArgSetterCallable_(
) except -1:
"""Setter for Callable"""
cdef object arg = <object>py_arg
- arg = _convert_to_ffi_func(arg)
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
- TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
+ cdef TVMFFIObjectHandle chandle
+ _convert_to_ffi_func_handle(arg, &chandle)
+ out.type_index = TVMFFIObjectGetTypeIndex(chandle)
+ out.v_ptr = chandle
+ TVMFFIPyPushTempFFIObject(ctx, chandle)
return 0
@@ -326,27 +396,79 @@ cdef int TVMFFIPyArgSetterException_(
return 0
+cdef int TVMFFIPyArgSetterTuple_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for Tuple"""
+ # recursively construct a new tuple
+ cdef TVMFFIObjectHandle chandle
+ ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, py_arg, &chandle, ctx)
+ out.type_index = TVMFFIObjectGetTypeIndex(chandle)
+ out.v_ptr = chandle
+ TVMFFIPyPushTempFFIObject(ctx, chandle)
+ return 0
+
+
+cdef int TVMFFIPyArgSetterTupleLike_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for TupleLike"""
+ # recursively construct a new tuple
+ cdef tuple tuple_arg = tuple(<object>py_arg)
+ cdef TVMFFIObjectHandle chandle
+ ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, <PyObject*>tuple_arg,
&chandle, ctx)
+ out.type_index = TVMFFIObjectGetTypeIndex(chandle)
+ out.v_ptr = chandle
+ TVMFFIPyPushTempFFIObject(ctx, chandle)
+ return 0
+
+
+cdef int TVMFFIPyArgSetterMap_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for Map"""
+ # recursively construct a new map
+ cdef dict dict_arg = <dict>py_arg
+ cdef list list_kvs = []
+ for k, v in dict_arg.items():
+ list_kvs.append(k)
+ list_kvs.append(v)
+ cdef tuple_arg_kvs = tuple(list_kvs)
+ cdef TVMFFIObjectHandle chandle
+ ConstructorCall(_CONSTRUCTOR_MAP.chandle, <PyObject*>tuple_arg_kvs,
&chandle, ctx)
+ out.type_index = TVMFFIObjectGetTypeIndex(chandle)
+ out.v_ptr = chandle
+ TVMFFIPyPushTempFFIObject(ctx, chandle)
+ return 0
+
+
+cdef int TVMFFIPyArgSetterObjectConvertible_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for ObjectConvertible"""
+ # recursively construct a new map
+ cdef object arg = <object>py_arg
+ arg = arg.asobject()
+ out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
+ out.v_ptr = (<Object>arg).chandle
+ TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
+
+
cdef int TVMFFIPyArgSetterFallback_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
) except -1:
"""Fallback setter for all other types"""
cdef object arg = <object>py_arg
- # fallback must contain PyNativeObject check
- if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
- arg = arg.__tvm_ffi_object__
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
- elif isinstance(arg, (list, tuple, dict, ObjectConvertible)):
- arg = _FUNC_CONVERT_TO_OBJECT(arg)
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
- TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
- else:
- arg = _convert_to_opaque_object(arg)
- out.type_index = kTVMFFIOpaquePyObject
- out.v_ptr = (<Object>arg).chandle
- TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
+ cdef TVMFFIObjectHandle chandle
+ _convert_to_opaque_object_handle(arg, &chandle)
+ out.type_index = kTVMFFIOpaquePyObject
+ out.v_ptr = chandle
+ TVMFFIPyPushTempFFIObject(ctx, chandle)
cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out)
except -1:
@@ -407,12 +529,32 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
if isinstance(arg, _CLASS_DEVICE):
out.func = TVMFFIPyArgSetterDevice_
return 0
+ if isinstance(arg, PyNativeObject):
+ # check for PyNativeObject
+ # this check must happen before str/bytes/tuple
+ if isinstance(arg, str):
+ out.func = TVMFFIPyArgSetterPyNativeObjectStr_
+ return 0
+ if isinstance(arg, bytes):
+ out.func = TVMFFIPyArgSetterPyNativeObjectBytes_
+ return 0
+ out.func = TVMFFIPyArgSetterPyNativeObjectGeneral_
+ return 0
if isinstance(arg, str):
out.func = TVMFFIPyArgSetterStr_
return 0
if isinstance(arg, (bytes, bytearray)):
out.func = TVMFFIPyArgSetterBytes_
return 0
+ if isinstance(arg, tuple):
+ out.func = TVMFFIPyArgSetterTuple_
+ return 0
+ if isinstance(arg, list):
+ out.func = TVMFFIPyArgSetterTupleLike_
+ return 0
+ if isinstance(arg, dict):
+ out.func = TVMFFIPyArgSetterMap_
+ return 0
if isinstance(arg, ctypes.c_void_p):
out.func = TVMFFIPyArgSetterCtypesVoidPtr_
return 0
@@ -422,6 +564,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
if isinstance(arg, Exception):
out.func = TVMFFIPyArgSetterException_
return 0
+ if isinstance(arg, ObjectConvertible):
+ out.func = TVMFFIPyArgSetterObjectConvertible_
+ return 0
# default to opaque object
out.func = TVMFFIPyArgSetterFallback_
return 0
@@ -429,24 +574,6 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
#---------------------------------------------------------------------------------------------
## Implementation of function calling
#---------------------------------------------------------------------------------------------
-cdef inline int ConstructorCall(void* constructor_handle,
- tuple args,
- void** handle) except -1:
- """Call contructor of a handle function"""
- cdef TVMFFIAny result
- cdef int c_api_ret_code
- # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone
- result.type_index = kTVMFFINone
- result.v_int64 = 0
- TVMFFIPyFuncCall(
- TVMFFIPyArgSetterFactory_, constructor_handle, <PyObject*>args,
&result, &c_api_ret_code,
- False, NULL
- )
- CHECK_CALL(c_api_ret_code)
- handle[0] = result.v_ptr
- return 0
-
-
cdef class Function(Object):
"""Python class that wraps a function with tvm-ffi ABI.
@@ -670,29 +797,45 @@ cdef int tvm_ffi_callback(void* context,
return -1
-def _convert_to_ffi_func(object pyfunc):
- """Convert a python function to TVM FFI function"""
- cdef TVMFFIObjectHandle chandle
+cdef inline int _convert_to_ffi_func_handle(
+ object pyfunc, TVMFFIObjectHandle* out_handle
+) except -1:
+ """Convert a python function to TVM FFI function handle"""
Py_INCREF(pyfunc)
CHECK_CALL(TVMFFIFunctionCreate(
<void*>(pyfunc),
tvm_ffi_callback,
tvm_ffi_pyobject_deleter,
- &chandle))
+ out_handle))
+ return 0
+
+
+def _convert_to_ffi_func(object pyfunc):
+ """Convert a python function to TVM FFI function"""
+ cdef TVMFFIObjectHandle chandle
+ _convert_to_ffi_func_handle(pyfunc, &chandle)
ret = Function.__new__(Function)
(<Object>ret).chandle = chandle
return ret
-def _convert_to_opaque_object(object pyobject):
- """Convert a python object to TVM FFI opaque object"""
- cdef TVMFFIObjectHandle chandle
+cdef inline int _convert_to_opaque_object_handle(
+ object pyobject, TVMFFIObjectHandle* out_handle
+) except -1:
+ """Convert a python object to TVM FFI opaque object handle"""
Py_INCREF(pyobject)
CHECK_CALL(TVMFFIObjectCreateOpaque(
<void*>(pyobject),
kTVMFFIOpaquePyObject,
tvm_ffi_pyobject_deleter,
- &chandle))
+ out_handle))
+ return 0
+
+
+def _convert_to_opaque_object(object pyobject):
+ """Convert a python object to TVM FFI opaque object"""
+ cdef TVMFFIObjectHandle chandle
+ _convert_to_opaque_object_handle(pyobject, &chandle)
ret = OpaquePyObject.__new__(OpaquePyObject)
(<Object>ret).chandle = chandle
return ret
@@ -704,7 +847,7 @@ def _print_debug_info():
print(f"TVMFFIPyGetDispatchMapSize: {size}")
-_STR_CONSTRUCTOR = _get_global_func("ffi.String", False)
-_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False)
-_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True)
-_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True)
+cdef Function _OBJECT_FROM_JSON_GRAPH_STR =
_get_global_func("ffi.FromJSONGraphString", True)
+cdef Function _OBJECT_TO_JSON_GRAPH_STR =
_get_global_func("ffi.ToJSONGraphString", True)
+cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True)
+cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True)
diff --git a/ffi/python/tvm_ffi/cython/object.pxi
b/ffi/python/tvm_ffi/cython/object.pxi
index 2a306e01ee..1d026b250f 100644
--- a/ffi/python/tvm_ffi/cython/object.pxi
+++ b/ffi/python/tvm_ffi/cython/object.pxi
@@ -17,17 +17,12 @@
import warnings
_CLASS_OBJECT = None
-_FUNC_CONVERT_TO_OBJECT = None
def _set_class_object(cls):
global _CLASS_OBJECT
_CLASS_OBJECT = cls
-def _set_func_convert_to_object(func):
- global _FUNC_CONVERT_TO_OBJECT
- _FUNC_CONVERT_TO_OBJECT = func
-
def __object_repr__(obj):
"""Object repr function that can be overridden by assigning to it"""
@@ -39,10 +34,6 @@ def _new_object(cls):
return cls.__new__(cls)
-_OBJECT_FROM_JSON_GRAPH_STR = None
-_OBJECT_TO_JSON_GRAPH_STR = None
-
-
class ObjectConvertible:
"""Base class for all classes that can be converted to object."""
@@ -144,7 +135,7 @@ cdef class Object:
self.chandle = NULL
cdef void* chandle
ConstructorCall(
- (<Object>fconstructor).chandle, args, &chandle)
+ (<Object>fconstructor).chandle, <PyObject*>args, &chandle, NULL)
self.chandle = chandle
def same_as(self, other):
diff --git a/ffi/python/tvm_ffi/cython/string.pxi
b/ffi/python/tvm_ffi/cython/string.pxi
index 4ab5c48ce0..0737259f22 100644
--- a/ffi/python/tvm_ffi/cython/string.pxi
+++ b/ffi/python/tvm_ffi/cython/string.pxi
@@ -78,8 +78,3 @@ class Bytes(bytes, PyNativeObject):
_register_object_by_index(kTVMFFIBytes, Bytes)
-
-# We special handle str/bytes constructor in cython to avoid extra cyclic deps
-# as the str/bytes construction must be done in the inner loop of function call
-_STR_CONSTRUCTOR = None
-_BYTES_CONSTRUCTOR = None
diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 87b426829d..325b878c4f 100644
--- a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -226,10 +226,7 @@ class TVMFFIPyCallManager {
try {
// recycle the temporary arguments if any
for (int i = 0; i < this->num_temp_ffi_objects; ++i) {
- TVMFFIObject* obj =
static_cast<TVMFFIObject*>(this->temp_ffi_objects[i]);
- if (obj->deleter != nullptr) {
- obj->deleter(obj, kTVMFFIObjectDeleterFlagBitMaskBoth);
- }
+ TVMFFIObjectDecRef(this->temp_ffi_objects[i]);
}
for (int i = 0; i < this->num_temp_py_objects; ++i) {
Py_DecRef(static_cast<PyObject*>(this->temp_py_objects[i]));
@@ -270,9 +267,9 @@ class TVMFFIPyCallManager {
* \return 0 on when there is no python error, -1 on python error
* \note When an error happens on FFI side, we should return 0 and set
c_api_ret_code
*/
- int Call(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
PyObject* py_arg_tuple,
- TVMFFIAny* result, int* c_api_ret_code, bool release_gil,
- DLPackToPyObject* optional_out_dlpack_importer) {
+ int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
PyObject* py_arg_tuple,
+ TVMFFIAny* result, int* c_api_ret_code, bool release_gil,
+ DLPackToPyObject* optional_out_dlpack_importer) {
int64_t num_args = PyTuple_Size(py_arg_tuple);
if (num_args == -1) return -1;
try {
@@ -331,6 +328,64 @@ class TVMFFIPyCallManager {
}
}
+ /*
+ * \brief Call a constructor with a variable number of arguments
+ *
+ * This function is similar to FuncCall, but it will not set the
+ * stream and tensor allocator, instead, it will synchronize the
TVMFFIPyCallContext
+ * with the parent context. This behavior is needed for nested conversion of
arguments
+ * where detected argument setting needs to be synchronized with final call.
+ *
+ * This function will also not release the GIL since constructor call is
usually cheap.
+ *
+ * \param setter_factory The factory function to create the setter
+ * \param func_handle The handle of the constructor to call
+ * \param py_arg_tuple The arguments to the constructor
+ * \param result The result of the constructor
+ * \param c_api_ret_code The return code of the constructor
+ * \param parent_ctx The parent call context to
+ * \return 0 on success, -1 on failure
+ */
+ int ConstructorCall(TVMFFIPyArgSetterFactory setter_factory, void*
func_handle,
+ PyObject* py_arg_tuple, TVMFFIAny* result, int*
c_api_ret_code,
+ TVMFFIPyCallContext* parent_ctx) {
+ int64_t num_args = PyTuple_Size(py_arg_tuple);
+ if (num_args == -1) return -1;
+ try {
+ // allocate a call stack
+ CallStack ctx(this, num_args);
+ // Iterate over the arguments and set them
+ for (int64_t i = 0; i < num_args; ++i) {
+ PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
+ TVMFFIAny* c_arg = ctx.packed_args + i;
+ if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
+ }
+ c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args,
num_args, result);
+ // propagate the call context to the parent context
+ if (parent_ctx != nullptr) {
+ // stream and current device information
+ if (parent_ctx->device_type == -1) {
+ parent_ctx->device_type = ctx.device_type;
+ parent_ctx->device_id = ctx.device_id;
+ parent_ctx->stream = ctx.stream;
+ }
+ // DLPack allocator
+ if (parent_ctx->c_dlpack_tensor_allocator == nullptr) {
+ parent_ctx->c_dlpack_tensor_allocator =
ctx.c_dlpack_tensor_allocator;
+ }
+ // DLPack importer
+ if (parent_ctx->c_dlpack_to_pyobject == nullptr) {
+ parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject;
+ }
+ }
+ return 0;
+ } catch (const std::exception& ex) {
+ // very rare, catch c++ exception and set python error
+ PyErr_SetString(PyExc_RuntimeError, ex.what());
+ return -1;
+ }
+ }
+
int SetField(TVMFFIPyArgSetterFactory setter_factory, TVMFFIFieldSetter
field_setter,
void* field_ptr, PyObject* py_arg, int* c_api_ret_code) {
try {
@@ -430,8 +485,36 @@ inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory
setter_factory, void* func_
PyObject* py_arg_tuple, TVMFFIAny* result, int*
c_api_ret_code,
bool release_gil = true,
DLPackToPyObject* out_dlpack_importer = nullptr) {
- return TVMFFIPyCallManager::ThreadLocal()->Call(setter_factory, func_handle,
py_arg_tuple, result,
- c_api_ret_code, release_gil,
out_dlpack_importer);
+ return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory,
func_handle, py_arg_tuple,
+ result, c_api_ret_code,
release_gil,
+ out_dlpack_importer);
+}
+
+/*!
+ * \brief Call a constructor function with a variable number of arguments
+ *
+ * This function is similar to TVMFFIPyFuncCall, but it will not set the
+ * stream and tensor allocator. Instead, it will synchronize the
TVMFFIPyCallContext
+ * with the parent context. This behavior is needed for nested conversion of
arguments
+ * where detected argument settings need to be synchronized with the final
call.
+ *
+ * This function will also not release the GIL since constructor call is
usually cheap.
+ *
+ * \param setter_factory The factory function to create the setter
+ * \param func_handle The handle of the function to call
+ * \param py_arg_tuple The arguments to the constructor
+ * \param result The result of the constructor
+ * \param c_api_ret_code The return code of the constructor
+ * \param parent_ctx The parent call context
+ * \param release_gil Whether to release the GIL
+ * \param out_dlpack_exporter The DLPack exporter to be used for the result
+ * \return 0 on success, nonzero on failure
+ */
+inline int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_factory,
void* func_handle,
+ PyObject* py_arg_tuple, TVMFFIAny* result,
int* c_api_ret_code,
+ TVMFFIPyCallContext* parent_ctx) {
+ return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall(
+ setter_factory, func_handle, py_arg_tuple, result, c_api_ret_code,
parent_ctx);
}
/*!
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 9f554e3356..292c8e913f 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -493,3 +493,21 @@ const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t
type_index) {
return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index);
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo);
}
+
+// string APIs, we blend into object.cc to keep things simple
+int TVMFFIStringFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ // must set to none first
+ out->type_index = kTVMFFINone;
+
tvm::ffi::TypeTraits<tvm::ffi::String>::MoveToAny(tvm::ffi::String(input->data,
input->size),
+ out);
+ TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ // must set to none first
+ out->type_index = kTVMFFINone;
+
tvm::ffi::TypeTraits<tvm::ffi::Bytes>::MoveToAny(tvm::ffi::Bytes(input->data,
input->size), out);
+ TVM_FFI_SAFE_CALL_END();
+}
diff --git a/ffi/tests/python/test_function.py
b/ffi/tests/python/test_function.py
index dfe22a1bad..b5a1da4f7d 100644
--- a/ffi/tests/python/test_function.py
+++ b/ffi/tests/python/test_function.py
@@ -97,6 +97,39 @@ def test_return_raw_str_bytes():
assert tvm_ffi.convert(lambda: bytearray(b"hello"))() == b"hello"
+def test_string_bytes_passing():
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ use_count = tvm_ffi.get_global_func("testing.object_use_count")
+ # small string
+ assert fecho("hello") == "hello"
+ # large string
+ x = "hello" * 100
+ y = fecho(x)
+ assert y == x
+ assert y.__tvm_ffi_object__ is not None
+ use_count(y) == 1
+ # small bytes
+ assert fecho(b"hello") == b"hello"
+ # large bytes
+ x = b"hello" * 100
+ y = fecho(x)
+ assert y == x
+ assert y.__tvm_ffi_object__ is not None
+ fecho(y) == 1
+
+
+def test_nested_container_passing():
+ # test and make sure our ref counting is correct
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ use_count = tvm_ffi.get_global_func("testing.object_use_count")
+ obj = tvm_ffi.convert((1, 2, 3))
+ assert use_count(obj) == 1
+ y = fecho([obj, {"a": 1, "b": obj}])
+ assert use_count(y) == 1
+ assert use_count(obj) == 3
+ assert use_count(y[1]) == 2
+
+
def test_pyfunc_convert():
def add(a, b):
return a + b
diff --git a/ffi/tests/python/test_load_inline.py
b/ffi/tests/python/test_load_inline.py
index 2aa01a62ee..0277803730 100644
--- a/ffi/tests/python/test_load_inline.py
+++ b/ffi/tests/python/test_load_inline.py
@@ -207,13 +207,10 @@ def test_load_inline_cuda_with_env_tensor_allocator():
pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
- cpp_sources=r"""
- #include <tvm/ffi/container/tensor.h>
-
- tvm::ffi::Tensor return_add_one(DLTensor* x);
- """,
cuda_sources=r"""
#include <tvm/ffi/container/tensor.h>
+ #include <tvm/ffi/container/tuple.h>
+ #include <tvm/ffi/container/map.h>
__global__ void AddOneKernel(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -223,7 +220,8 @@ def test_load_inline_cuda_with_env_tensor_allocator():
}
namespace ffi = tvm::ffi;
- ffi::Tensor return_add_one(DLTensor* x) {
+ ffi::Tensor return_add_one(ffi::Map<ffi::String,
ffi::Tuple<ffi::Tensor>> kwargs) {
+ ffi::Tensor x = kwargs["x"].get<0>();
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -251,7 +249,8 @@ def test_load_inline_cuda_with_env_tensor_allocator():
if torch is not None:
x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32,
device="cuda")
- y_cuda = mod.return_add_one(x_cuda)
+ # test support for nested container passing
+ y_cuda = mod.return_add_one({"x": [x_cuda]})
assert isinstance(y_cuda, torch.Tensor)
assert y_cuda.shape == (5,)
assert y_cuda.dtype == torch.float32
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index e36935c8d2..067a4f0d4a 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -49,7 +49,7 @@ struct DiscoProtocol {
/*! \brief Recycle all the memory used in the arena */
inline void RecycleAll() {
- this->object_arena_.clear();
+ this->any_arena_.clear();
this->arena_.RecycleAll();
}
@@ -81,7 +81,7 @@ struct DiscoProtocol {
}
support::Arena arena_;
- std::vector<Any> object_arena_;
+ std::vector<Any> any_arena_;
friend struct RPCReference;
};
@@ -213,7 +213,7 @@ inline void
DiscoProtocol<SubClassType>::ReadFFIAny(TVMFFIAny* out) {
<< Object::TypeIndex2Key(type_index) << " (type_index = " <<
type_index << ")";
}
*reinterpret_cast<ffi::AnyView*>(out) = result;
- object_arena_.push_back(result);
+ any_arena_.push_back(result);
}
inline std::string DiscoDebugObject::SaveToStr() const {
diff --git a/src/runtime/minrpc/rpc_reference.h
b/src/runtime/minrpc/rpc_reference.h
index ee08ad12c7..8b21b24927 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -472,7 +472,9 @@ struct RPCReference {
break;
}
default: {
- if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin ||
+ type_index == ffi::TypeIndex::kTVMFFISmallStr ||
+ type_index == ffi::TypeIndex::kTVMFFISmallBytes) {
channel->ReadFFIAny(&(packed_args[i]));
} else {
channel->ThrowError(RPCServerStatus::kUnknownTypeIndex);
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index c51484b279..0778b55394 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -171,6 +171,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
for (int i = 0; i < args.size(); ++i) {
if (args[i] == nullptr) continue;
if (args[i].type_index() == ffi::TypeIndex::kTVMFFIModule) continue;
+ if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
+ args[i].type_index() == ffi::TypeIndex::kTVMFFISmallBytes)
+ continue;
+ if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr ||
+ args[i].type_index() == ffi::TypeIndex::kTVMFFIBytes)
+ continue;
if (const Object* obj = args[i].as<Object>()) {
if (!obj->IsInstance<RPCObjectRefObj>()) {
LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
<< obj->GetTypeKey()
@@ -221,14 +227,20 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
void WriteFFIAny(const TVMFFIAny* in) {
// NOTE: for now all remote object are encoded as RPCObjectRef
// follow the same disco protocol in case we would like to upgrade later
- //
- // Rationale note: Only handle remote object allows the same mechanism to work for minRPC
- // which is needed for wasm and other env that goes through C API
+ // TODO(tqchen): consider merge with disco protocol
const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(in);
if (const auto* ref = any_view_ptr->as<RPCObjectRefObj>()) {
this->template Write<uint32_t>(runtime::TypeIndex::kRuntimeRPCObjectRef);
uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
this->template Write<int64_t>(handle);
+ } else if (auto opt_str = any_view_ptr->as<ffi::String>()) {
+ this->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIStr);
+ this->template Write<uint64_t>((*opt_str).size());
+ this->template WriteArray<char>((*opt_str).data(), (*opt_str).size());
+ } else if (auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) {
+ this->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIBytes);
+ this->template Write<uint64_t>((*opt_bytes).size());
+ this->template WriteArray<char>((*opt_bytes).data(),
(*opt_bytes).size());
} else {
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling
convention: "
<< any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
@@ -239,6 +251,10 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
const AnyView* any_view_ptr = reinterpret_cast<const AnyView*>(in);
if (any_view_ptr->as<RPCObjectRefObj>()) {
return sizeof(uint32_t) + sizeof(int64_t);
+ } else if (auto opt_str = any_view_ptr->as<ffi::String>()) {
+ return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_str).size();
+ } else if (auto opt_bytes = any_view_ptr->as<ffi::Bytes>()) {
+ return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_bytes).size();
} else {
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling
convention: "
<< any_view_ptr->GetTypeKey() << " (type_index = " <<
any_view_ptr->type_index()
@@ -266,7 +282,23 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
// Legacy ABI translation
// TODO(tqchen): remove this once we have upgraded to new ABI
*reinterpret_cast<AnyView*>(out) = rpc_obj;
- object_arena_.push_back(rpc_obj);
+ any_arena_.emplace_back(rpc_obj);
+ } else if (type_index == ffi::TypeIndex::kTVMFFIStr) {
+ uint64_t size;
+ this->template Read<uint64_t>(&size);
+ std::string data(size, '\0');
+ this->template ReadArray<char>(data.data(), size);
+ ffi::String ret(std::move(data));
+ *reinterpret_cast<AnyView*>(out) = ret;
+ any_arena_.emplace_back(ret);
+ } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) {
+ uint64_t size;
+ this->template Read<uint64_t>(&size);
+ std::string data(size, '\0');
+ this->template ReadArray<char>(data.data(), size);
+ ffi::Bytes ret(std::move(data));
+ *reinterpret_cast<AnyView*>(out) = ret;
+ any_arena_.emplace_back(ret);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
<< Object::TypeIndex2Key(type_index) << " (type_index = " <<
type_index << ")";
@@ -285,7 +317,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
/*! \brief Recycle all the memory used in the arena */
void RecycleAll() {
- this->object_arena_.clear();
+ this->any_arena_.clear();
this->arena_.RecycleAll();
}
@@ -310,7 +342,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
// Internal arena
support::Arena arena_;
// internal arena for temp objects
- std::vector<ObjectRef> object_arena_;
+ std::vector<ffi::Any> any_arena_;
// State switcher
void SwitchToState(State state) {