This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d68ed568 Add PrimitiveArray::new (#3879) (#3909)
2d68ed568 is described below

commit 2d68ed5686a2a41d1e486b4dc562c9a19db76c07
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Mar 23 12:24:55 2023 +0000

    Add PrimitiveArray::new (#3879) (#3909)
    
    * Add PrimitiveArray::new (#3879)
    
    * Review feedback
    
    * Format
---
 arrow-arith/src/arity.rs                 |  26 ++------
 arrow-array/src/array/primitive_array.rs | 101 +++++++++++++++++--------------
 arrow-buffer/src/buffer/scalar.rs        |   7 +++
 3 files changed, 69 insertions(+), 65 deletions(-)

diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index 0a8815cc8..782c8270c 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -23,25 +23,10 @@ use arrow_array::types::ArrowDictionaryKeyType;
 use arrow_array::*;
 use arrow_buffer::buffer::NullBuffer;
 use arrow_buffer::{Buffer, MutableBuffer};
-use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_data::ArrayData;
 use arrow_schema::ArrowError;
 use std::sync::Arc;
 
-#[inline]
-unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
-    len: usize,
-    buffer: Buffer,
-    nulls: Option<NullBuffer>,
-) -> PrimitiveArray<O> {
-    PrimitiveArray::from(
-        ArrayDataBuilder::new(O::DATA_TYPE)
-            .len(len)
-            .nulls(nulls)
-            .buffers(vec![buffer])
-            .build_unchecked(),
-    )
-}
-
 /// See [`PrimitiveArray::unary`]
 pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
 where
@@ -209,7 +194,6 @@ where
             "Cannot perform binary operation on arrays of different length".to_string(),
         ));
     }
-    let len = a.len();
 
     if a.is_empty() {
         return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
@@ -224,8 +208,7 @@ where
     //  Soundness
     //      `values` is an iterator with a known size from a PrimitiveArray
     let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-
-    Ok(unsafe { build_primitive_array(len, buffer, nulls) })
+    Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls))
 }
 
 /// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating
@@ -328,7 +311,8 @@ where
             Ok::<_, ArrowError>(())
         })?;
 
-        Ok(unsafe { build_primitive_array(len, buffer.finish(), Some(nulls)) })
+        let values = buffer.finish().into();
+        Ok(PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls)))
     }
 }
 
@@ -412,7 +396,7 @@ where
             buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
         };
     }
-    Ok(unsafe { build_primitive_array(len, buffer.into(), None) })
+    Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), None))
 }
 
 /// This intentional inline(never) attribute helps LLVM optimize the loop.
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 241e2a051..6faecb1f0 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -29,7 +29,7 @@ use arrow_buffer::{
     i256, ArrowNativeType, BooleanBuffer, Buffer, NullBuffer, ScalarBuffer,
 };
 use arrow_data::bit_iterator::try_for_each_valid_idx;
-use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType};
 use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime};
 use half::f16;
@@ -251,19 +251,58 @@ pub struct PrimitiveArray<T: ArrowPrimitiveType> {
     /// Underlying ArrayData
     data: ArrayData,
     /// Values data
-    raw_values: ScalarBuffer<T::Native>,
+    values: ScalarBuffer<T::Native>,
 }
 
 impl<T: ArrowPrimitiveType> Clone for PrimitiveArray<T> {
     fn clone(&self) -> Self {
         Self {
             data: self.data.clone(),
-            raw_values: self.raw_values.clone(),
+            values: self.values.clone(),
         }
     }
 }
 
 impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
+    /// Create a new [`PrimitiveArray`] from the provided data_type, values, nulls
+    ///
+    /// # Panics
+    ///
+    /// Panics if:
+    /// - `values.len() != nulls.len()`
+    /// - `!Self::is_compatible(data_type)`
+    pub fn new(
+        data_type: DataType,
+        values: ScalarBuffer<T::Native>,
+        nulls: Option<NullBuffer>,
+    ) -> Self {
+        Self::assert_compatible(&data_type);
+        if let Some(n) = nulls.as_ref() {
+            assert_eq!(values.len(), n.len());
+        }
+
+        // TODO: Don't store ArrayData inside arrays (#3880)
+        let data = unsafe {
+            ArrayData::builder(data_type)
+                .len(values.len())
+                .nulls(nulls)
+                .buffers(vec![values.inner().clone()])
+                .build_unchecked()
+        };
+
+        Self { data, values }
+    }
+
+    /// Asserts that `data_type` is compatible with `Self`
+    fn assert_compatible(data_type: &DataType) {
+        assert!(
+            Self::is_compatible(data_type),
+            "PrimitiveArray expected data type {} got {}",
+            T::DATA_TYPE,
+            data_type
+        );
+    }
+
     /// Returns the length of this array.
     #[inline]
     pub fn len(&self) -> usize {
@@ -278,7 +317,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
     /// Returns the values of this array
     #[inline]
     pub fn values(&self) -> &ScalarBuffer<T::Native> {
-        &self.raw_values
+        &self.values
     }
 
     /// Returns a new primitive array builder
@@ -308,7 +347,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
     /// caller must ensure that the passed in offset is less than the array len()
     #[inline]
     pub unsafe fn value_unchecked(&self, i: usize) -> T::Native {
-        *self.raw_values.get_unchecked(i)
+        *self.values.get_unchecked(i)
     }
 
     /// Returns the primitive value at index `i`.
@@ -346,7 +385,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
     pub fn from_value(value: T::Native, count: usize) -> Self {
         unsafe {
             let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
-            build_primitive_array(count, val_buf, None)
+            Self::new(T::DATA_TYPE, val_buf.into(), None)
         }
     }
 
@@ -422,7 +461,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
         F: Fn(T::Native) -> O::Native,
     {
         let data = self.data();
-        let len = self.len();
 
         let nulls = data.nulls().cloned();
         let values = self.values().iter().map(|v| op(*v));
@@ -432,7 +470,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
         //  Soundness
         //      `values` is an iterator with a known size because arrays are sized.
         let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-        unsafe { build_primitive_array(len, buffer, nulls) }
+        PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls)
     }
 
     /// Applies an unary and infallible function to a mutable primitive array.
@@ -495,7 +533,8 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
             None => (0..len).try_for_each(f)?,
         }
 
-        Ok(unsafe { build_primitive_array(len, buffer.finish(), nulls) })
+        let values = buffer.finish().into();
+        Ok(PrimitiveArray::new(O::DATA_TYPE, values, nulls))
     }
 
     /// Applies an unary and fallible function to all valid values in a mutable primitive array.
@@ -579,13 +618,9 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
         });
 
         let nulls = BooleanBuffer::new(null_builder.finish(), 0, len);
-        unsafe {
-            build_primitive_array(
-                len,
-                buffer.finish(),
-                Some(NullBuffer::new_unchecked(nulls, out_null_count)),
-            )
-        }
+        let values = buffer.finish().into();
+        let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) };
+        PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls))
     }
 
     /// Returns `PrimitiveBuilder` of this primitive array for mutating its values if the underlying
@@ -599,7 +634,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
             .slice_with_length(self.data.offset() * element_len, len * element_len);
 
         drop(self.data);
-        drop(self.raw_values);
+        drop(self.values);
 
         let try_mutable_null_buffer = match null_bit_buffer {
             None => Ok(None),
@@ -647,21 +682,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
     }
 }
 
-#[inline]
-unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
-    len: usize,
-    buffer: Buffer,
-    nulls: Option<NullBuffer>,
-) -> PrimitiveArray<O> {
-    PrimitiveArray::from(
-        ArrayDataBuilder::new(O::DATA_TYPE)
-            .len(len)
-            .buffers(vec![buffer])
-            .nulls(nulls)
-            .build_unchecked(),
-    )
-}
-
 impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
     fn from(array: PrimitiveArray<T>) -> Self {
         array.data
@@ -1052,21 +1072,16 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
 /// Constructs a `PrimitiveArray` from an array data reference.
 impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
     fn from(data: ArrayData) -> Self {
-        assert!(
-            Self::is_compatible(data.data_type()),
-            "PrimitiveArray expected ArrayData with type {} got {}",
-            T::DATA_TYPE,
-            data.data_type()
-        );
+        Self::assert_compatible(data.data_type());
         assert_eq!(
             data.buffers().len(),
             1,
             "PrimitiveArray data should contain a single buffer only (values buffer)"
         );
 
-        let raw_values =
+        let values =
             ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len());
-        Self { data, raw_values }
+        Self { data, values }
     }
 }
 
@@ -1833,9 +1848,7 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(
-        expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
-    )]
+    #[should_panic(expected = "PrimitiveArray expected data type Int64 got Int32")]
     fn test_from_array_data_validation() {
         let foo = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
         let _ = PrimitiveArray::<Int64Type>::from(foo.into_data());
@@ -2211,7 +2224,7 @@ mod tests {
 
     #[test]
     #[should_panic(
-        expected = "PrimitiveArray expected ArrayData with type Interval(MonthDayNano) got Interval(DayTime)"
+        expected = "PrimitiveArray expected data type Interval(MonthDayNano) got Interval(DayTime)"
     )]
     fn test_invalid_interval_type() {
         let array = IntervalDayTimeArray::from(vec![1, 2, 3]);
diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs
index 4c16a736b..1a4680111 100644
--- a/arrow-buffer/src/buffer/scalar.rs
+++ b/arrow-buffer/src/buffer/scalar.rs
@@ -17,6 +17,7 @@
 
 use crate::buffer::Buffer;
 use crate::native::ArrowNativeType;
+use crate::MutableBuffer;
 use std::fmt::Formatter;
 use std::marker::PhantomData;
 use std::ops::Deref;
@@ -96,6 +97,12 @@ impl<T: ArrowNativeType> AsRef<[T]> for ScalarBuffer<T> {
     }
 }
 
+impl<T: ArrowNativeType> From<MutableBuffer> for ScalarBuffer<T> {
+    fn from(value: MutableBuffer) -> Self {
+        Buffer::from(value).into()
+    }
+}
+
 impl<T: ArrowNativeType> From<Buffer> for ScalarBuffer<T> {
     fn from(buffer: Buffer) -> Self {
         let align = std::mem::align_of::<T>();

Reply via email to