This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 54f527f feat: Support Array Concat in Python (#45)
54f527f is described below
commit 54f527f4d3d1ae1b950e0fe1a52ccb7bfc6df249
Author: Junru Shao <[email protected]>
AuthorDate: Tue Sep 23 08:15:55 2025 -0700
feat: Support Array Concat in Python (#45)
This pull request enhances the `tvm_ffi.Array` class by introducing
support for array concatenation using the standard Python `+` operator.
This feature allows `Array` objects to be combined seamlessly with other
`Array` instances or Python sequences, making the `Array` class more
intuitive and consistent with Python's built-in sequence types.
Note that many standard `list`/`dict` operations are not yet supported in
`tvm_ffi`; they are left for future work.
---
python/tvm_ffi/container.py | 15 ++++++++--
tests/python/test_container.py | 67 ++++++++++++++++++++++++++++++++++++++++++
2 files changed, 79 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index e3b2fb7..f179fc9 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -18,9 +18,10 @@
from __future__ import annotations
+import itertools
import operator
from collections.abc import ItemsView as ItemsViewBase
-from collections.abc import Iterator, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import KeysView as KeysViewBase
from collections.abc import ValuesView as ValuesViewBase
from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload
@@ -89,7 +90,7 @@ class Array(core.Object, Sequence[T]):
Parameters
----------
- input_list : Sequence[T]
+ input_list : Iterable[T]
The list of values to be stored in the array.
See Also
@@ -108,7 +109,7 @@ class Array(core.Object, Sequence[T]):
"""
- def __init__(self, input_list: Sequence[T]) -> None:
+ def __init__(self, input_list: Iterable[T]) -> None:
"""Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
@@ -143,6 +144,14 @@ class Array(core.Object, Sequence[T]):
return type(self).__name__ + "(chandle=None)"
return "[" + ", ".join([x.__repr__() for x in self]) + "]"
+ def __add__(self, other: Iterable[T]) -> Array[T]:
+ """Return ``self + other`` as a new array; ``other`` may be any iterable."""
+ return type(self)(itertools.chain(self, other))
+
+ def __radd__(self, other: Iterable[T]) -> Array[T]:
+ """Return ``other + self`` as a new array; ``other`` may be any iterable."""
+ return type(self)(itertools.chain(other, self))
+
class KeysView(KeysViewBase[K]):
"""Helper class to return keys view."""
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index a025062..54b41b7 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pickle
+from collections.abc import Sequence
from typing import Any
import pytest
@@ -126,6 +127,72 @@ def test_serialization() -> None:
assert str(b) == "[1, 2, 3]"
+@pytest.mark.parametrize(
+ "a, b, c_expected",
+ [
+ (
+ tvm_ffi.Array([1, 2, 3]),
+ tvm_ffi.Array([4, 5, 6]),
+ tvm_ffi.Array([1, 2, 3, 4, 5, 6]),
+ ),
+ (
+ tvm_ffi.Array([1, 2, 3]),
+ [4, 5, 6],
+ tvm_ffi.Array([1, 2, 3, 4, 5, 6]),
+ ),
+ (
+ [1, 2, 3],
+ tvm_ffi.Array([4, 5, 6]),
+ tvm_ffi.Array([1, 2, 3, 4, 5, 6]),
+ ),
+ (
+ tvm_ffi.Array([]),
+ tvm_ffi.Array([1, 2, 3]),
+ tvm_ffi.Array([1, 2, 3]),
+ ),
+ (
+ tvm_ffi.Array([1, 2, 3]),
+ [],
+ tvm_ffi.Array([1, 2, 3]),
+ ),
+ (
+ [],
+ tvm_ffi.Array([1, 2, 3]),
+ tvm_ffi.Array([1, 2, 3]),
+ ),
+ (
+ tvm_ffi.Array([]),
+ [],
+ tvm_ffi.Array([]),
+ ),
+ (
+ [],
+ tvm_ffi.Array([]),
+ tvm_ffi.Array([]),
+ ),
+ (
+ tvm_ffi.Array([1, 2, 3]),
+ (4, 5, 6),
+ tvm_ffi.Array([1, 2, 3, 4, 5, 6]),
+ ),
+ (
+ (1, 2, 3),
+ tvm_ffi.Array([4, 5, 6]),
+ tvm_ffi.Array([1, 2, 3, 4, 5, 6]),
+ ),
+ ],
+)
+def test_array_concat(
+ a: Sequence[int],
+ b: Sequence[int],
+ c_expected: Sequence[int],
+) -> None:
+ c_actual = a + b # type: ignore[operator]
+ assert type(c_actual) is type(c_expected)
+ assert len(c_actual) == len(c_expected)
+ assert tuple(c_actual) == tuple(c_expected)
+
+
def test_large_map_get() -> None:
amap = tvm_ffi.convert({k: k**2 for k in range(100)})
assert amap.get(101) is None