This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 2d68ed568 Add PrimitiveArray::new (#3879) (#3909)
2d68ed568 is described below
commit 2d68ed5686a2a41d1e486b4dc562c9a19db76c07
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Mar 23 12:24:55 2023 +0000
Add PrimitiveArray::new (#3879) (#3909)
* Add PrimitiveArray::new (#3879)
* Review feedback
* Format
---
arrow-arith/src/arity.rs | 26 ++------
arrow-array/src/array/primitive_array.rs | 101 +++++++++++++++++--------------
arrow-buffer/src/buffer/scalar.rs | 7 +++
3 files changed, 69 insertions(+), 65 deletions(-)
diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index 0a8815cc8..782c8270c 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -23,25 +23,10 @@ use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, MutableBuffer};
-use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_data::ArrayData;
use arrow_schema::ArrowError;
use std::sync::Arc;
-#[inline]
-unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
- len: usize,
- buffer: Buffer,
- nulls: Option<NullBuffer>,
-) -> PrimitiveArray<O> {
- PrimitiveArray::from(
- ArrayDataBuilder::new(O::DATA_TYPE)
- .len(len)
- .nulls(nulls)
- .buffers(vec![buffer])
- .build_unchecked(),
- )
-}
-
/// See [`PrimitiveArray::unary`]
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
@@ -209,7 +194,6 @@ where
"Cannot perform binary operation on arrays of different length".to_string(),
));
}
- let len = a.len();
if a.is_empty() {
return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
@@ -224,8 +208,7 @@ where
// Soundness
// `values` is an iterator with a known size from a PrimitiveArray
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-
- Ok(unsafe { build_primitive_array(len, buffer, nulls) })
+ Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls))
}
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating
@@ -328,7 +311,8 @@ where
Ok::<_, ArrowError>(())
})?;
- Ok(unsafe { build_primitive_array(len, buffer.finish(), Some(nulls)) })
+ let values = buffer.finish().into();
+ Ok(PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls)))
}
}
@@ -412,7 +396,7 @@ where
buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
};
}
- Ok(unsafe { build_primitive_array(len, buffer.into(), None) })
+ Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), None))
}
/// This intentional inline(never) attribute helps LLVM optimize the loop.
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 241e2a051..6faecb1f0 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -29,7 +29,7 @@ use arrow_buffer::{
i256, ArrowNativeType, BooleanBuffer, Buffer, NullBuffer, ScalarBuffer,
};
use arrow_data::bit_iterator::try_for_each_valid_idx;
-use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType};
use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime};
use half::f16;
@@ -251,19 +251,58 @@ pub struct PrimitiveArray<T: ArrowPrimitiveType> {
/// Underlying ArrayData
data: ArrayData,
/// Values data
- raw_values: ScalarBuffer<T::Native>,
+ values: ScalarBuffer<T::Native>,
}
impl<T: ArrowPrimitiveType> Clone for PrimitiveArray<T> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
- raw_values: self.raw_values.clone(),
+ values: self.values.clone(),
}
}
}
impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
+ /// Create a new [`PrimitiveArray`] from the provided data_type, values, nulls
+ ///
+ /// # Panics
+ ///
+ /// Panics if:
+ /// - `values.len() != nulls.len()`
+ /// - `!Self::is_compatible(data_type)`
+ pub fn new(
+ data_type: DataType,
+ values: ScalarBuffer<T::Native>,
+ nulls: Option<NullBuffer>,
+ ) -> Self {
+ Self::assert_compatible(&data_type);
+ if let Some(n) = nulls.as_ref() {
+ assert_eq!(values.len(), n.len());
+ }
+
+ // TODO: Don't store ArrayData inside arrays (#3880)
+ let data = unsafe {
+ ArrayData::builder(data_type)
+ .len(values.len())
+ .nulls(nulls)
+ .buffers(vec![values.inner().clone()])
+ .build_unchecked()
+ };
+
+ Self { data, values }
+ }
+
+ /// Asserts that `data_type` is compatible with `Self`
+ fn assert_compatible(data_type: &DataType) {
+ assert!(
+ Self::is_compatible(data_type),
+ "PrimitiveArray expected data type {} got {}",
+ T::DATA_TYPE,
+ data_type
+ );
+ }
+
/// Returns the length of this array.
#[inline]
pub fn len(&self) -> usize {
@@ -278,7 +317,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// Returns the values of this array
#[inline]
pub fn values(&self) -> &ScalarBuffer<T::Native> {
- &self.raw_values
+ &self.values
}
/// Returns a new primitive array builder
@@ -308,7 +347,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// caller must ensure that the passed in offset is less than the array len()
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> T::Native {
- *self.raw_values.get_unchecked(i)
+ *self.values.get_unchecked(i)
}
/// Returns the primitive value at index `i`.
@@ -346,7 +385,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
pub fn from_value(value: T::Native, count: usize) -> Self {
unsafe {
let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
- build_primitive_array(count, val_buf, None)
+ Self::new(T::DATA_TYPE, val_buf.into(), None)
}
}
@@ -422,7 +461,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
F: Fn(T::Native) -> O::Native,
{
let data = self.data();
- let len = self.len();
let nulls = data.nulls().cloned();
let values = self.values().iter().map(|v| op(*v));
@@ -432,7 +470,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
- unsafe { build_primitive_array(len, buffer, nulls) }
+ PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls)
}
/// Applies an unary and infallible function to a mutable primitive array.
@@ -495,7 +533,8 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
None => (0..len).try_for_each(f)?,
}
- Ok(unsafe { build_primitive_array(len, buffer.finish(), nulls) })
+ let values = buffer.finish().into();
+ Ok(PrimitiveArray::new(O::DATA_TYPE, values, nulls))
}
/// Applies an unary and fallible function to all valid values in a mutable primitive array.
@@ -579,13 +618,9 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
});
let nulls = BooleanBuffer::new(null_builder.finish(), 0, len);
- unsafe {
- build_primitive_array(
- len,
- buffer.finish(),
- Some(NullBuffer::new_unchecked(nulls, out_null_count)),
- )
- }
+ let values = buffer.finish().into();
+ let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) };
+ PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls))
}
/// Returns `PrimitiveBuilder` of this primitive array for mutating its values if the underlying
@@ -599,7 +634,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
.slice_with_length(self.data.offset() * element_len, len * element_len);
drop(self.data);
- drop(self.raw_values);
+ drop(self.values);
let try_mutable_null_buffer = match null_bit_buffer {
None => Ok(None),
@@ -647,21 +682,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
}
}
-#[inline]
-unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
- len: usize,
- buffer: Buffer,
- nulls: Option<NullBuffer>,
-) -> PrimitiveArray<O> {
- PrimitiveArray::from(
- ArrayDataBuilder::new(O::DATA_TYPE)
- .len(len)
- .buffers(vec![buffer])
- .nulls(nulls)
- .build_unchecked(),
- )
-}
-
impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
fn from(array: PrimitiveArray<T>) -> Self {
array.data
@@ -1052,21 +1072,16 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
- assert!(
- Self::is_compatible(data.data_type()),
- "PrimitiveArray expected ArrayData with type {} got {}",
- T::DATA_TYPE,
- data.data_type()
- );
+ Self::assert_compatible(data.data_type());
assert_eq!(
data.buffers().len(),
1,
"PrimitiveArray data should contain a single buffer only (values buffer)"
);
- let raw_values =
+ let values =
ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len());
- Self { data, raw_values }
+ Self { data, values }
}
}
@@ -1833,9 +1848,7 @@ mod tests {
}
#[test]
- #[should_panic(
- expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
- )]
+ #[should_panic(expected = "PrimitiveArray expected data type Int64 got Int32")]
fn test_from_array_data_validation() {
let foo = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
let _ = PrimitiveArray::<Int64Type>::from(foo.into_data());
@@ -2211,7 +2224,7 @@ mod tests {
#[test]
#[should_panic(
- expected = "PrimitiveArray expected ArrayData with type Interval(MonthDayNano) got Interval(DayTime)"
+ expected = "PrimitiveArray expected data type Interval(MonthDayNano) got Interval(DayTime)"
)]
fn test_invalid_interval_type() {
let array = IntervalDayTimeArray::from(vec![1, 2, 3]);
diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs
index 4c16a736b..1a4680111 100644
--- a/arrow-buffer/src/buffer/scalar.rs
+++ b/arrow-buffer/src/buffer/scalar.rs
@@ -17,6 +17,7 @@
use crate::buffer::Buffer;
use crate::native::ArrowNativeType;
+use crate::MutableBuffer;
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::ops::Deref;
@@ -96,6 +97,12 @@ impl<T: ArrowNativeType> AsRef<[T]> for ScalarBuffer<T> {
}
}
+impl<T: ArrowNativeType> From<MutableBuffer> for ScalarBuffer<T> {
+ fn from(value: MutableBuffer) -> Self {
+ Buffer::from(value).into()
+ }
+}
+
impl<T: ArrowNativeType> From<Buffer> for ScalarBuffer<T> {
fn from(buffer: Buffer) -> Self {
let align = std::mem::align_of::<T>();