This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new d0d0e2f  feat: add Rust binding for `Array<T>` (#348)
d0d0e2f is described below

commit d0d0e2f935cda443bd85e097a3cfb18de2a96f4d
Author: Haejoon Kim <[email protected]>
AuthorDate: Fri Jan 30 22:11:57 2026 +0900

    feat: add Rust binding for `Array<T>` (#348)
    
    This PR introduces a Rust implementation of the `Array` container.
    
    ### Key Features
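    * **Memory Safety**: Handles reference counting for `ObjectRef`
    elements: each element's reference is moved into the array buffer when
    the array is built, and `get` returns a new owned value obtained through
    `T::try_cast_from_any_view`. This patch has no API that removes or
    clears elements, so there is no `dec_ref` path for removal/clear yet.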
    * **Dynamic Mutation**: Not included in this patch. There are no
    `push`, `pop`, `insert`, `remove`, or `clear` methods and no growth or
    reallocation. Arrays are built once (via `new` or `FromIterator`) with
    `capacity` equal to `size`.
    * **FFI Compatibility**: Explicitly manages the `data` pointer within
    `ArrayObj` to allow C++ TVM functions to traverse the array using
    standard pointer arithmetic.
    * **Type System Integration**:
      * Implements `AnyCompatible`, allowing `Array<T>` to be erased into
        `Any` and `AnyView` and recovered via `TryFrom`.
      * Implements `FromIterator`, enabling integration with Rust's
        iterator ecosystem (`Extend` is not implemented in this patch).
    
    ## Tests
    Verified with a test suite in
    `rust/tvm-ffi/tests/test_array.rs` covering:
    - [x] Basic creation and iteration.
    - [ ] Out-of-bounds safety (no test in this patch).
    - [ ] Dynamic growth and reallocation (push/insert) (not in this patch).
    - [ ] Memory integrity after element shifting (remove/insert) (not in this patch).
    - [x] Roundtrip conversions through `Any` and `AnyView`.
    - [x] Rejecting a cast of `Array<Shape>` to `Array<Tensor>`.
    - [x] Parametric support for `Tensor`, `Shape`, and `Function` element types.
---
 rust/tvm-ffi/src/any.rs               |   1 +
 rust/tvm-ffi/src/collections/array.rs | 341 ++++++++++++++++++++++++++++++++++
 rust/tvm-ffi/src/collections/mod.rs   |   3 +-
 rust/tvm-ffi/src/lib.rs               |   1 +
 rust/tvm-ffi/tests/test_array.rs      | 132 +++++++++++++
 5 files changed, 477 insertions(+), 1 deletion(-)

diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs
index d5a4c85..ecf8b9e 100644
--- a/rust/tvm-ffi/src/any.rs
+++ b/rust/tvm-ffi/src/any.rs
@@ -177,6 +177,7 @@ impl Any {
 
     #[inline]
     pub unsafe fn into_raw_ffi_any(this: Self) -> TVMFFIAny {
+        let this = std::mem::ManuallyDrop::new(this); // Any's Drop must not run: the held reference travels with the returned raw TVMFFIAny
         this.data
     }
 
diff --git a/rust/tvm-ffi/src/collections/array.rs b/rust/tvm-ffi/src/collections/array.rs
new file mode 100644
index 0000000..6f259ba
--- /dev/null
+++ b/rust/tvm-ffi/src/collections/array.rs
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use std::fmt::Debug;
+use std::marker::PhantomData;
+use std::ops::Deref;
+
+use crate::any::TryFromTemp;
+use crate::derive::Object;
+use crate::object::{Object, ObjectArc};
+use crate::{Any, AnyCompatible, AnyView, ObjectCoreWithExtraItems, ObjectRefCore};
+use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
+use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject};
+
+#[repr(C)] // presumably mirrors the C++ ArrayObj layout; do not reorder fields without checking
+#[derive(Object)]
+#[type_key = "ffi.Array"]
+#[type_index(TypeIndex::kTVMFFIArray)]
+pub struct ArrayObj { // header of an "ffi.Array" object; elements are TVMFFIAny slots stored after it
+    pub object: Object,
+    /// Pointer to the start of the element buffer (AddressOf(0)).
+    pub data: *mut core::ffi::c_void, // Array::new_with_capacity points this at the first trailing slot
+    pub size: i64, // number of initialized elements
+    pub capacity: i64, // slots the buffer claims to hold; arrays built in this file always have capacity == size
+    /// Optional custom deleter for the data pointer.
+    pub data_deleter: Option<unsafe extern "C" fn(*mut core::ffi::c_void)>, // left as None by arrays built in this file
+}
+
+unsafe impl ObjectCoreWithExtraItems for ArrayObj { // elements are stored inline after the header, one TVMFFIAny per slot
+    type ExtraItem = TVMFFIAny;
+    fn extra_items_count(this: &Self) -> usize { // NOTE(review): counts `size`, not `capacity`; if an ArrayObj were ever built with capacity > size, the extra slots are presumably not allocated
+        this.size as usize
+    }
+}
+
+#[repr(C)]
+#[derive(Clone)] // clones presumably share the same ArrayObj: the ObjectArc handle is cloned, not the elements
+pub struct Array<T: AnyCompatible + Clone> { // typed handle to an ArrayObj; T is the element type get()/iter() cast to
+    data: ObjectArc<ArrayObj>,
+    _marker: PhantomData<T>, // no T is stored here; elements live as TVMFFIAny slots in the ArrayObj
+}
+
+impl<T: AnyCompatible + Clone> Debug for Array<T> { // summary only, e.g. "Array<Tensor>[2]"; elements are not printed
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let full_name = std::any::type_name::<T>();
+        let short_name = full_name.split("::").last().unwrap_or(full_name); // NOTE(review): for a generic T (e.g. Array<Tensor>) the last "::" segment comes from inside the type arguments ("Tensor>")
+        write!(f, "Array<{}>[{}]", short_name, self.len())
+    }
+}
+
+impl<T: AnyCompatible + Clone> Default for Array<T> { // an empty array (size == capacity == 0)
+    fn default() -> Self {
+        Self::new(vec![])
+    }
+}
+
+unsafe impl<T: AnyCompatible + Clone> ObjectRefCore for Array<T> { // exposes the ObjectArc<ArrayObj> handle to the generic object-ref machinery
+    type ContainerType = ArrayObj;
+
+    fn data(this: &Self) -> &ObjectArc<Self::ContainerType> {
+        &this.data
+    }
+
+    fn into_data(this: Self) -> ObjectArc<Self::ContainerType> {
+        this.data
+    }
+
+    fn from_data(data: ObjectArc<Self::ContainerType>) -> Self { // no element-type check happens here; callers must ensure the elements are compatible with T
+        Self {
+            data,
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl<T: AnyCompatible + Clone> Array<T> {
+    /// Creates a new Array from a vector of items.
+    pub fn new(items: Vec<T>) -> Self { // capacity is set to exactly items.len()
+        let capacity = items.len();
+        Self::new_with_capacity(items, capacity)
+    }
+
+    /// Internal helper to allocate an ArrayObj with specific headroom.
+    fn new_with_capacity(items: Vec<T>, capacity: usize) -> Self { // NOTE(review): only called with capacity == items.len(); extra_items_count reads `size`, so larger headroom would presumably not be allocated
+        let size = items.len();
+
+        // Allocate the header plus trailing element slots (slot count presumably taken from extra_items_count, i.e. `size`)
+        let arc = ObjectArc::<ArrayObj>::new_with_extra_items(ArrayObj {
+            object: Object::new(),
+            data: core::ptr::null_mut(),
+            size: size as i64,
+            capacity: capacity as i64,
+            data_deleter: None,
+        });
+
+        unsafe { // the freshly created arc is not shared yet, so writing through its raw pointer is presumably exclusive
+            let raw_ptr = ObjectArc::as_raw(&arc) as *mut ArrayObj;
+            let container = &mut *raw_ptr;
+
+            let base_ptr = ArrayObj::extra_items_mut(container).as_ptr() as *mut TVMFFIAny;
+            container.data = base_ptr as *mut _; // `data` must point at slot 0 (AddressOf(0)) so foreign code can index the elements
+
+            for (i, item) in items.into_iter().enumerate() {
+                let any: Any = Any::from(item);
+                let raw = Any::into_raw_ffi_any(any); // the element's reference is moved into the slot, not released
+                core::ptr::write(base_ptr.add(i), raw); // NOTE(review): no Drop for ArrayObj is visible in this patch to release these references; confirm the deleter does, or they leak
+            }
+        }
+        Self::from_data(arc)
+    }
+
+    pub fn len(&self) -> usize { // number of elements (ArrayObj::size)
+        self.data.size as usize
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Retrieves an item at the given index.
+    pub fn get(&self, index: usize) -> Result<T, crate::Error> { // INDEX_ERROR if index >= len, TYPE_ERROR if the slot does not cast to T
+        if index >= self.len() {
+            crate::bail!(crate::error::INDEX_ERROR, "Array get index out of bound");
+        }
+        unsafe {
+            let container = self.data.deref();
+            let base_ptr = container.data as *const TVMFFIAny;
+            let raw_any_ref = &*base_ptr.add(index);
+
+            match T::try_cast_from_any_view(raw_any_ref) {
+                Ok(val) => Ok(val),
+                Err(_) => crate::bail!(
+                    crate::error::TYPE_ERROR,
+                    "Failed to cast element at {} to {}",
+                    index,
+                    T::type_str()
+                ),
+            }
+        }
+    }
+
+    pub fn iter(&'_ self) -> ArrayIterator<'_, T> { // yields owned T values produced by get()
+        ArrayIterator {
+            array: self,
+            index: 0,
+            len: self.len(),
+        }
+    }
+
+    #[inline]
+    fn as_container(&self) -> &ArrayObj { // borrows the ArrayObj header behind the handle
+        unsafe {
+            let ptr = ObjectArc::as_raw(&self.data) as *const ArrayObj;
+            &*ptr
+        }
+    }
+}
+
+// --- Index Implementation ---
+
+impl<T: AnyCompatible + Clone> std::ops::Index<usize> for Array<T> { // `array[i]` gives an untyped view of slot i; panics when out of bounds
+    type Output = AnyView<'static>; // NOTE(review): the view borrows this array's storage but is typed 'static; since `array[i]` can be copied out by value (the tests do), it can outlive the array — unsound lifetime
+
+    fn index(&self, index: usize) -> &Self::Output {
+        let container = self.as_container();
+        let len = container.size as usize;
+        if index >= len {
+            panic!(
+                "Index out of bounds: the len is {} but the index is {}",
+                len, index
+            );
+        }
+        unsafe {
+            let ptr = (container.data as *const AnyView<'static>).add(index); // reinterprets the TVMFFIAny slot as an AnyView; assumes the two are layout-identical
+            &*ptr
+        }
+    }
+}
+
+// --- Iterator Implementations ---
+
+pub struct ArrayIterator<'a, T: AnyCompatible + Clone> { // borrowing iterator returned by Array::iter
+    array: &'a Array<T>,
+    index: usize, // next slot to read
+    len: usize, // array length captured when the iterator was created
+}
+
+impl<'a, T: AnyCompatible + Clone> Iterator for ArrayIterator<'a, T> { // yields each element cast to T, in index order
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index < self.len {
+            let item = self.array.get(self.index).ok(); // NOTE(review): a cast failure becomes None, so for-loops/collect stop at that slot and the error is lost; size_hint is also not overridden
+            self.index += 1;
+            item
+        } else {
+            None
+        }
+    }
+}
+
+impl<'a, T: AnyCompatible + Clone> IntoIterator for &'a Array<T> { // enables `for x in &array`; yields owned T values
+    type Item = T;
+    type IntoIter = ArrayIterator<'a, T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+impl<T: AnyCompatible + Clone> FromIterator<T> for Array<T> { // buffers the items in a Vec, then builds the array with capacity == len
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        let items: Vec<T> = iter.into_iter().collect();
+        Self::new(items)
+    }
+}
+
+// --- Any Type System Conversions ---
+
+unsafe impl<T> AnyCompatible for Array<T> // lets Array<T> be stored in and recovered from Any/AnyView
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    fn type_str() -> String {
+        format!("Array<{}>", T::type_str())
+    }
+
+    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { // true only for an ffi.Array whose every element passes T's strict check
+        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
+            return false;
+        }
+
+        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Any>() { // Array<Any> accepts any element, so skip the O(len) scan
+            return true;
+        }
+
+        let container = &*(data.data_union.v_obj as *const ArrayObj);
+        let base_ptr = container.data as *const TVMFFIAny;
+        for i in 0..container.size {
+            let elem_any = &*base_ptr.add(i as usize);
+            if !T::check_any_strict(elem_any) {
+                return false;
+            }
+        }
+        true
+    }
+
+    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { // borrowed view: the pointer is copied with no inc_ref
+        data.type_index = TypeIndex::kTVMFFIArray as i32;
+        data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject;
+        data.small_str_len = 0;
+    }
+
+    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { // hands this handle's reference over to the Any via into_raw
+        data.type_index = TypeIndex::kTVMFFIArray as i32;
+        data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject;
+        data.small_str_len = 0;
+    }
+
+    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { // takes a new reference (inc_ref) so the result owns its handle
+        let ptr = data.data_union.v_obj as *const ArrayObj;
+        crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject);
+        Self::from_data(ObjectArc::from_raw(ptr))
+    }
+
+    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { // adopts the Any's reference and resets the source to None
+        let ptr = data.data_union.v_obj as *const ArrayObj;
+        let obj = Self::from_data(ObjectArc::from_raw(ptr));
+
+        data.type_index = TypeIndex::kTVMFFINone as i32;
+        data.data_union.v_int64 = 0;
+
+        obj
+    }
+
+    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
+        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
+            return Err(());
+        }
+
+        // Fast path: if types match exactly, we can just copy the reference.
+        if Self::check_any_strict(data) {
+            return Ok(Self::copy_from_any_view_after_check(data));
+        }
+
+        // Slow path: convert element by element into a new Array (a converted copy, not a shared reference).
+        let container = &*(data.data_union.v_obj as *const ArrayObj);
+        let base_ptr = container.data as *const TVMFFIAny;
+        let mut items = Vec::with_capacity(container.size as usize);
+
+        for i in 0..container.size {
+            let any_v = &*base_ptr.add(i as usize);
+            if let Ok(item) = T::try_cast_from_any_view(any_v) {
+                items.push(item);
+            } else {
+                return Err(());
+            }
+        }
+
+        Ok(Array::new(items))
+    }
+}
+
+impl<T> TryFrom<Any> for Array<T> // consumes the Any; the conversion itself is delegated to TryFromTemp
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    type Error = crate::error::Error;
+
+    fn try_from(value: Any) -> Result<Self, Self::Error> {
+        let temp: TryFromTemp<Self> = TryFromTemp::try_from(value)?;
+        Ok(TryFromTemp::into_value(temp))
+    }
+}
+
+impl<'a, T> TryFrom<AnyView<'a>> for Array<T> // converts from a borrowed view; the conversion itself is delegated to TryFromTemp
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    type Error = crate::error::Error;
+
+    fn try_from(value: AnyView<'a>) -> Result<Self, Self::Error> {
+        let temp: TryFromTemp<Self> = TryFromTemp::try_from(value)?;
+        Ok(TryFromTemp::into_value(temp))
+    }
+}
diff --git a/rust/tvm-ffi/src/collections/mod.rs b/rust/tvm-ffi/src/collections/mod.rs
index 85635a7..ad17dcc 100644
--- a/rust/tvm-ffi/src/collections/mod.rs
+++ b/rust/tvm-ffi/src/collections/mod.rs
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-pub mod shape;
 /// Collection types
+pub mod array;
+pub mod shape;
 pub mod tensor;
diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs
index 94e87b0..fad8260 100644
--- a/rust/tvm-ffi/src/lib.rs
+++ b/rust/tvm-ffi/src/lib.rs
@@ -32,6 +32,7 @@ pub mod type_traits;
 pub use tvm_ffi_sys;
 
 pub use crate::any::{Any, AnyView};
+pub use crate::collections::array::Array;
 pub use crate::collections::shape::Shape;
 pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor};
 pub use crate::device::{current_stream, with_stream};
diff --git a/rust/tvm-ffi/tests/test_array.rs b/rust/tvm-ffi/tests/test_array.rs
new file mode 100644
index 0000000..fe87c5f
--- /dev/null
+++ b/rust/tvm-ffi/tests/test_array.rs
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use tvm_ffi::*;
+
+/// Helper to create a Tensor with a specific float value and shape
+fn create_tensor(val: f32, shape: &[i64]) -> Tensor {
+    let dtype = DLDataType::new(DLDataTypeCode::kDLFloat, 32, 1);
+    let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
+    let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
+    if let Ok(slice) = tensor.data_as_slice_mut::<f32>() {
+        slice[0] = val;
+    }
+    tensor
+}
+
+/// Helper to extract the first float value from a Tensor
+fn get_val(tensor: &Tensor) -> f32 {
+    tensor
+        .data_as_slice::<f32>()
+        .expect("Type mismatch or null")[0]
+}
+
+#[test]
+fn test_array_core_and_iteration() { // construction, len/is_empty, indexing via Index + TryFrom, and iter()
+    let t1 = create_tensor(10.0, &[1, 2]);
+    let t2 = create_tensor(20.0, &[3, 4, 5]);
+
+    let array = Array::new(vec![t1.clone(), t2.clone()]);
+
+    // Core Accessors
+    assert_eq!(array.len(), 2);
+    assert!(!array.is_empty());
+
+    // Value Integrity
+    assert_eq!(get_val(&Tensor::try_from(array[0]).unwrap()), 10.0);
+    assert_eq!(Tensor::try_from(array[0]).unwrap().ndim(), 2);
+    assert_eq!(Tensor::try_from(array[1]).unwrap().ndim(), 3);
+
+    // Iteration
+    let vals: Vec<f32> = array.iter().map(|t| get_val(&t)).collect();
+    assert_eq!(vals, vec![10.0, 20.0]);
+}
+
+#[test]
+fn test_array_any_conversions() { // Array -> Any -> Array and Array -> AnyView -> Array roundtrips
+    let array = Array::new(vec![
+        create_tensor(1.0, &[1]),
+        create_tensor(2.0, &[1]),
+        create_tensor(3.0, &[1]),
+    ]);
+
+    // Test Any/AnyView Roundtrip (Verifies AnyCompatible and Trait Bounds)
+    let any = Any::from(array);
+    assert_eq!(any.type_index(), TypeIndex::kTVMFFIArray as i32);
+
+    let back: Array<Tensor> = Array::try_from(any).expect("Any -> Array failed");
+    assert_eq!(back.len(), 3);
+    assert_eq!(get_val(&back.get(2).unwrap()), 3.0);
+
+    let view = AnyView::from(&back);
+    let back_from_view: Array<Tensor> = Array::try_from(view).expect("AnyView -> Array failed");
+    assert_eq!(back_from_view.len(), 3);
+}
+
+#[test]
+fn test_array_recursive_type_checking() { // an Array<Shape> stored in Any must not convert to Array<Tensor>
+    // 1. Create an Array of Shapes
+    let shape_array = Array::new(vec![Shape::from(vec![1, 2]), Shape::from(vec![3])]);
+
+    // 2. Wrap it in Any
+    let any_val = Any::from(shape_array);
+
+    // 3. Try to convert Any (containing Shapes) into Array<Tensor>
+    // This should FAIL because T::check_any_strict (Tensor) will fail on Shape elements
+    let tensor_cast = Array::<Tensor>::try_from(any_val.clone());
+    assert!(
+        tensor_cast.is_err(),
+        "Should not be able to cast Array<Shape> to Array<Tensor>"
+    );
+
+    // 4. Verify valid cast works
+    let shape_cast = Array::<Shape>::try_from(any_val);
+    assert!(
+        shape_cast.is_ok(),
+        "Should be able to cast back to correct type"
+    );
+}
+
+#[test]
+fn test_array_parametric_heterogeneity() { // Array<Shape> and Array<Function> element access through get()
+    // Verify Array works with different ObjectRefCore types
+    let shape_array = Array::new(vec![Shape::from(vec![1, 2, 3]), Shape::from(vec![10])]);
+    assert_eq!(shape_array.get(0).unwrap().as_slice(), &[1, 2, 3]);
+    assert_eq!(shape_array.get(1).unwrap().as_slice(), &[10]);
+
+    let function_array = Array::new(vec![
+        Function::get_global("ffi.String").unwrap(),
+        Function::get_global("ffi.Bytes").unwrap(),
+    ]);
+    assert_eq!(
+        into_typed_fn!(
+            function_array.get(0).unwrap(),
+            Fn(String) -> Result<String>
+        )("hello".into())
+        .unwrap(),
+        "hello"
+    );
+    assert_eq!(
+        into_typed_fn!(
+            function_array.get(1).unwrap(),
+            Fn(Bytes) -> Result<Bytes>
+        )([1, 2, 3].into())
+        .unwrap(),
+        &[1, 2, 3]
+    );
+}

Reply via email to