This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new ee2c29236 Add Datum based arithmetic kernels (#3999) (#4465)
ee2c29236 is described below
commit ee2c29236077094724a8031c17af2562e96dbd07
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Jul 8 14:05:26 2023 -0400
Add Datum based arithmetic kernels (#3999) (#4465)
* Add Datum based arithmetic kernels (#3999)
* Fix benchmark
* Review feedback
---
arrow-arith/src/aggregate.rs | 39 +-
arrow-arith/src/arithmetic.rs | 766 ++++---------------------------
arrow-arith/src/lib.rs | 2 +
arrow-arith/src/numeric.rs | 672 +++++++++++++++++++++++++++
arrow-array/src/array/primitive_array.rs | 9 +
arrow-array/src/scalar.rs | 6 +
arrow/benches/arithmetic_kernels.rs | 40 +-
arrow/src/compute/kernels/mod.rs | 4 +-
arrow/src/ffi.rs | 34 +-
9 files changed, 830 insertions(+), 742 deletions(-)
diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs
index 4961d7efc..04417c666 100644
--- a/arrow-arith/src/aggregate.rs
+++ b/arrow-arith/src/aggregate.rs
@@ -867,8 +867,8 @@ where
#[cfg(test)]
mod tests {
use super::*;
- use crate::arithmetic::add;
use arrow_array::types::*;
+ use arrow_buffer::NullBuffer;
use std::sync::Arc;
#[test]
@@ -897,54 +897,35 @@ mod tests {
#[test]
fn test_primitive_array_sum_large_64() {
- let a: Int64Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(i) } else { None })
- .collect();
- let b: Int64Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) })
- .collect();
// create an array that actually has non-zero values at the invalid indices
- let c = add(&a, &b).unwrap();
+ let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect());
+ let c = Int64Array::new((1..=100).collect(), Some(validity));
+
assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
}
#[test]
fn test_primitive_array_sum_large_32() {
- let a: Int32Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(i) } else { None })
- .collect();
- let b: Int32Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) })
- .collect();
// create an array that actually has non-zero values at the invalid indices
- let c = add(&a, &b).unwrap();
+ let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect());
+ let c = Int32Array::new((1..=100).collect(), Some(validity));
assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
}
#[test]
fn test_primitive_array_sum_large_16() {
- let a: Int16Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(i) } else { None })
- .collect();
- let b: Int16Array = (1..=100)
- .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) })
- .collect();
// create an array that actually has non-zero values at the invalid indices
- let c = add(&a, &b).unwrap();
+ let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect());
+ let c = Int16Array::new((1..=100).collect(), Some(validity));
assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
}
#[test]
fn test_primitive_array_sum_large_8() {
// include fewer values than other large tests so the result does not overflow the u8
- let a: UInt8Array = (1..=100)
- .map(|i| if i % 33 == 0 { Some(i) } else { None })
- .collect();
- let b: UInt8Array = (1..=100)
- .map(|i| if i % 33 == 0 { Some(0) } else { Some(i) })
- .collect();
// create an array that actually has non-zero values at the invalid indices
- let c = add(&a, &b).unwrap();
+ let validity = NullBuffer::new((1..=100).map(|x| x % 33 == 0).collect());
+ let c = UInt8Array::new((1..=100).collect(), Some(validity));
assert_eq!(Some((1..=100).filter(|i| i % 33 == 0).sum()), sum(&c));
}
diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 8e7ab4404..4f6ecc78d 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -23,7 +23,6 @@
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
use crate::arity::*;
-use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::i256;
@@ -39,6 +38,7 @@ use std::sync::Arc;
/// # Errors
///
/// This function errors if the arrays have different lengths
+#[deprecated(note = "Use arrow_arith::arity::binary")]
pub fn math_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
@@ -52,43 +52,6 @@ where
binary(left, right, op)
}
-/// This is similar to `math_op` as it performs given operation between two
input primitive arrays.
-/// But the given operation can return `Err` if overflow is detected. For the
case, this function
-/// returns an `Err`.
-fn math_checked_op<LT, RT, F>(
- left: &PrimitiveArray<LT>,
- right: &PrimitiveArray<RT>,
- op: F,
-) -> Result<PrimitiveArray<LT>, ArrowError>
-where
- LT: ArrowNumericType,
- RT: ArrowNumericType,
- F: Fn(LT::Native, RT::Native) -> Result<LT::Native, ArrowError>,
-{
- try_binary(left, right, op)
-}
-
-/// Helper function for operations where a valid `0` on the right array should
-/// result in an [ArrowError::DivideByZero], namely the division and modulo
operations
-///
-/// # Errors
-///
-/// This function errors if:
-/// * the arrays have different lengths
-/// * there is an element where both left and right values are valid and the
right value is `0`
-fn math_checked_divide_op<LT, RT, F>(
- left: &PrimitiveArray<LT>,
- right: &PrimitiveArray<RT>,
- op: F,
-) -> Result<PrimitiveArray<LT>, ArrowError>
-where
- LT: ArrowNumericType,
- RT: ArrowNumericType,
- F: Fn(LT::Native, RT::Native) -> Result<LT::Native, ArrowError>,
-{
- math_checked_op(left, right, op)
-}
-
/// Calculates the modulus operation `left % right` on two SIMD inputs.
/// The lower-most bits of `valid_mask` specify which vector lanes are
considered as valid.
///
@@ -335,11 +298,12 @@ where
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")]
pub fn add<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_op(left, right, |a, b| a.add_wrapping(b))
+ binary(left, right, |a, b| a.add_wrapping(b))
}
/// Perform `left + right` operation on two arrays. If either left or right value is null
@@ -347,11 +311,12 @@ pub fn add<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `add` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add")]
pub fn add_checked<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_checked_op(left, right, |a, b| a.add_checked(b))
+ try_binary(left, right, |a, b| a.add_checked(b))
}
/// Perform `left + right` operation on two arrays. If either left or right value is null
@@ -359,176 +324,9 @@ pub fn add_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_dyn_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")]
pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, ArrowError> {
- match left.data_type() {
- DataType::Date32 => {
- let l = left.as_primitive::<Date32Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date32Type::add_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date32Type::add_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r, Date32Type::add_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Date64 => {
- let l = left.as_primitive::<Date64Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date64Type::add_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date64Type::add_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r, Date64Type::add_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- let l = left.as_primitive::<TimestampSecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampSecondType::add_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampSecondType::add_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampSecondType::add_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
-
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- let l = left.as_primitive::<TimestampMicrosecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::add_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::add_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::add_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
-
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- let l = left.as_primitive::<TimestampMillisecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::add_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::add_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::add_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
-
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- let l = left.as_primitive::<TimestampNanosecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::add_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::add_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::add_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
-
- DataType::Interval(_)
- if matches!(
- right.data_type(),
- DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _)
- ) =>
- {
- add_dyn(right, left)
- }
- _ => {
- downcast_primitive_array!(
- (left, right) => {
- math_op(left, right, |a, b| a.add_wrapping(b)).map(|a|
Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
- }
- }
+ crate::numeric::add_wrapping(&left, &right)
}
/// Perform `left + right` operation on two arrays. If either left or right
value is null
@@ -536,71 +334,12 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, ArrowErr
///
/// This detects overflow and returns an `Err` for that. For an
non-overflow-checking variant,
/// use `add_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add")]
pub fn add_dyn_checked(
left: &dyn Array,
right: &dyn Array,
) -> Result<ArrayRef, ArrowError> {
- match left.data_type() {
- DataType::Date32 => {
- let l = left.as_primitive::<Date32Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date32Type::add_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date32Type::add_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r, Date32Type::add_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Date64 => {
- let l = left.as_primitive::<Date64Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date64Type::add_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date64Type::add_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r, Date64Type::add_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- _ => {
- downcast_primitive_array!(
- (left, right) => {
- math_checked_op(left, right, |a, b|
a.add_checked(b)).map(|a| Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
- }
- }
+ crate::numeric::add(&left, &right)
}
/// Add every value in an array by a scalar. If any value in the array is null
then the
@@ -608,6 +347,7 @@ pub fn add_dyn_checked(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `add_scalar_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")]
pub fn add_scalar<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -620,6 +360,7 @@ pub fn add_scalar<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For an
non-overflow-checking variant,
/// use `add_scalar` instead.
+#[deprecated(note = "Use arrow_arith::numeric::add")]
pub fn add_scalar_checked<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -635,6 +376,7 @@ pub fn add_scalar_checked<T: ArrowNumericType>(
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
///
/// This returns an `Err` when the input array is not supported for adding
operation.
+#[deprecated(note = "Use arrow_arith::numeric::add_wrapping")]
pub fn add_scalar_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -651,6 +393,7 @@ pub fn add_scalar_dyn<T: ArrowNumericType>(
///
/// As this kernel has the branching costs and also prevents LLVM from
vectorising it correctly,
/// it is usually much slower than non-checking variant.
+#[deprecated(note = "Use arrow_arith::numeric::add")]
pub fn add_scalar_checked_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -664,11 +407,12 @@ pub fn add_scalar_checked_dyn<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `subtract_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")]
pub fn subtract<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_op(left, right, |a, b| a.sub_wrapping(b))
+ binary(left, right, |a, b| a.sub_wrapping(b))
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -676,11 +420,12 @@ pub fn subtract<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For an
non-overflow-checking variant,
/// use `subtract` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub")]
pub fn subtract_checked<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_checked_op(left, right, |a, b| a.sub_checked(b))
+ try_binary(left, right, |a, b| a.sub_checked(b))
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -688,184 +433,9 @@ pub fn subtract_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `subtract_dyn_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")]
pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, ArrowError> {
- match left.data_type() {
- DataType::Date32 => {
- let l = left.as_primitive::<Date32Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date32Type::subtract_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date32Type::subtract_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r,
Date32Type::subtract_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Date64 => {
- let l = left.as_primitive::<Date64Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date64Type::subtract_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date64Type::subtract_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r,
Date64Type::subtract_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- let l = left.as_primitive::<TimestampSecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampSecondType::subtract_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampSecondType::subtract_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampSecondType::subtract_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- let r = right.as_primitive::<TimestampSecondType>();
- let res: PrimitiveArray<DurationSecondType> = binary(l, r,
|a, b| a.wrapping_sub(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- let l = left.as_primitive::<TimestampMicrosecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::subtract_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::subtract_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampMicrosecondType::subtract_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- let r = right.as_primitive::<TimestampMicrosecondType>();
- let res: PrimitiveArray<DurationMicrosecondType> =
binary(l, r, |a, b| a.wrapping_sub(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- let l = left.as_primitive::<TimestampMillisecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::subtract_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::subtract_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampMillisecondType::subtract_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- let r = right.as_primitive::<TimestampMillisecondType>();
- let res: PrimitiveArray<DurationMillisecondType> =
binary(l, r, |a, b| a.wrapping_sub(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- let l = left.as_primitive::<TimestampNanosecondType>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::subtract_year_months)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::subtract_day_time)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_checked_op(l, r,
TimestampNanosecondType::subtract_month_day_nano)?;
- Ok(Arc::new(res.with_timezone_opt(l.timezone())))
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- let r = right.as_primitive::<TimestampNanosecondType>();
- let res: PrimitiveArray<DurationNanosecondType> =
binary(l, r, |a, b| a.wrapping_sub(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- _ => {
- downcast_primitive_array!(
- (left, right) => {
- math_op(left, right, |a, b| a.sub_wrapping(b)).map(|a|
Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
- }
- }
+ crate::numeric::sub_wrapping(&left, &right)
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -873,127 +443,12 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, Arr
///
/// This detects overflow and returns an `Err` for that. For an
non-overflow-checking variant,
/// use `subtract_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub")]
pub fn subtract_dyn_checked(
left: &dyn Array,
right: &dyn Array,
) -> Result<ArrayRef, ArrowError> {
- match left.data_type() {
- DataType::Date32 => {
- let l = left.as_primitive::<Date32Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date32Type::subtract_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date32Type::subtract_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r,
Date32Type::subtract_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Date64 => {
- let l = left.as_primitive::<Date64Type>();
- match right.data_type() {
- DataType::Interval(IntervalUnit::YearMonth) => {
- let r = right.as_primitive::<IntervalYearMonthType>();
- let res = math_op(l, r, Date64Type::subtract_year_months)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- let r = right.as_primitive::<IntervalDayTimeType>();
- let res = math_op(l, r, Date64Type::subtract_day_time)?;
- Ok(Arc::new(res))
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- let r = right.as_primitive::<IntervalMonthDayNanoType>();
- let res = math_op(l, r,
Date64Type::subtract_month_day_nano)?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- let l = left.as_primitive::<TimestampSecondType>();
- match right.data_type() {
- DataType::Timestamp(TimeUnit::Second, _) => {
- let r = right.as_primitive::<TimestampSecondType>();
- let res: PrimitiveArray<DurationSecondType> =
try_binary(l, r, |a, b| a.sub_checked(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- let l = left.as_primitive::<TimestampMicrosecondType>();
- match right.data_type() {
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- let r = right.as_primitive::<TimestampMicrosecondType>();
- let res: PrimitiveArray<DurationMicrosecondType> =
try_binary(l, r, |a, b| a.sub_checked(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- let l = left.as_primitive::<TimestampMillisecondType>();
- match right.data_type() {
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- let r = right.as_primitive::<TimestampMillisecondType>();
- let res: PrimitiveArray<DurationMillisecondType> =
try_binary(l, r, |a, b| a.sub_checked(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- let l = left.as_primitive::<TimestampNanosecondType>();
- match right.data_type() {
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- let r = right.as_primitive::<TimestampNanosecondType>();
- let res: PrimitiveArray<DurationNanosecondType> =
try_binary(l, r, |a, b| a.sub_checked(b))?;
- Ok(Arc::new(res))
- }
- _ => Err(ArrowError::CastError(format!(
- "Cannot perform arithmetic operation between array of type
{} and array of type {}",
- left.data_type(), right.data_type()
- ))),
- }
- }
- _ => {
- downcast_primitive_array!(
- (left, right) => {
- math_checked_op(left, right, |a, b|
a.sub_checked(b)).map(|a| Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
- }
- }
+ crate::numeric::sub(&left, &right)
}
/// Subtract every value in an array by a scalar. If any value in the array is
null then the
@@ -1001,6 +456,7 @@ pub fn subtract_dyn_checked(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `subtract_scalar_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")]
pub fn subtract_scalar<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -1013,6 +469,7 @@ pub fn subtract_scalar<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For an
non-overflow-checking variant,
/// use `subtract_scalar` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub")]
pub fn subtract_scalar_checked<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -1026,6 +483,7 @@ pub fn subtract_scalar_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn`
instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub_wrapping")]
pub fn subtract_scalar_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -1039,6 +497,7 @@ pub fn subtract_scalar_dyn<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `subtract_scalar_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::sub")]
pub fn subtract_scalar_checked_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -1072,11 +531,12 @@ pub fn negate_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `multiply_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")]
pub fn multiply<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_op(left, right, |a, b| a.mul_wrapping(b))
+ binary(left, right, |a, b| a.mul_wrapping(b))
}
/// Perform `left * right` operation on two arrays. If either left or right
value is null
@@ -1084,11 +544,12 @@ pub fn multiply<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `multiply` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul")]
pub fn multiply_checked<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
- math_checked_op(left, right, |a, b| a.mul_checked(b))
+ try_binary(left, right, |a, b| a.mul_checked(b))
}
/// Perform `left * right` operation on two arrays. If either left or right
value is null
@@ -1096,16 +557,9 @@ pub fn multiply_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `multiply_dyn_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")]
pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef,
ArrowError> {
- downcast_primitive_array!(
- (left, right) => {
- math_op(left, right, |a, b| a.mul_wrapping(b)).map(|a| Arc::new(a)
as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
+ crate::numeric::mul_wrapping(&left, &right)
}
/// Perform `left * right` operation on two arrays. If either left or right
value is null
@@ -1113,19 +567,12 @@ pub fn multiply_dyn(left: &dyn Array, right: &dyn Array)
-> Result<ArrayRef, Arr
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `multiply_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul")]
pub fn multiply_dyn_checked(
left: &dyn Array,
right: &dyn Array,
) -> Result<ArrayRef, ArrowError> {
- downcast_primitive_array!(
- (left, right) => {
- math_checked_op(left, right, |a, b| a.mul_checked(b)).map(|a|
Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
+ crate::numeric::mul(&left, &right)
}
/// Returns the precision and scale of the result of a multiplication of two
decimal types,
@@ -1210,8 +657,10 @@ pub fn multiply_fixed_point_checked(
)?;
if required_scale == product_scale {
- return multiply_checked(left, right)?
- .with_precision_and_scale(precision, required_scale);
+ return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
+ a.mul_checked(b)
+ })?
+ .with_precision_and_scale(precision, required_scale);
}
try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
@@ -1254,7 +703,7 @@ pub fn multiply_fixed_point(
)?;
if required_scale == product_scale {
- return multiply(left, right)?
+ return binary(left, right, |a, b| a.mul_wrapping(b))?
.with_precision_and_scale(precision, required_scale);
}
@@ -1294,6 +743,7 @@ where
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `multiply_scalar_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")]
pub fn multiply_scalar<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -1306,6 +756,7 @@ pub fn multiply_scalar<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `multiply_scalar` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul")]
pub fn multiply_scalar_checked<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
scalar: T::Native,
@@ -1319,6 +770,7 @@ pub fn multiply_scalar_checked<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn`
instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul_wrapping")]
pub fn multiply_scalar_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -1332,6 +784,7 @@ pub fn multiply_scalar_dyn<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `multiply_scalar_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::mul")]
pub fn multiply_scalar_checked_dyn<T: ArrowNumericType>(
array: &dyn Array,
scalar: T::Native,
@@ -1343,6 +796,7 @@ pub fn multiply_scalar_checked_dyn<T: ArrowNumericType>(
/// Perform `left % right` operation on two arrays. If either left or right
value is null
/// then the result is also null. If any right hand value is zero then the
result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
+#[deprecated(note = "Use arrow_arith::numeric::rem")]
pub fn modulus<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
@@ -1364,22 +818,9 @@ pub fn modulus<T: ArrowNumericType>(
/// Perform `left % right` operation on two arrays. If either left or right
value is null
/// then the result is also null. If any right hand value is zero then the
result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
+#[deprecated(note = "Use arrow_arith::numeric::rem")]
pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef,
ArrowError> {
- downcast_primitive_array!(
- (left, right) => {
- math_checked_divide_op(left, right, |a, b| {
- if b.is_zero() {
- Err(ArrowError::DivideByZero)
- } else {
- Ok(a.mod_wrapping(b))
- }
- }).map(|a| Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
+ crate::numeric::rem(&left, &right)
}
/// Perform `left / right` operation on two arrays. If either left or right
value is null
@@ -1388,6 +829,7 @@ pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) ->
Result<ArrayRef, Arro
///
/// When `simd` feature is not enabled. This detects overflow and returns an
`Err` for that.
/// For a non-overflow-checking variant, use `divide` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_checked<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
@@ -1397,7 +839,7 @@ pub fn divide_checked<T: ArrowNumericType>(
a.div_wrapping(b)
});
#[cfg(not(feature = "simd"))]
- return math_checked_divide_op(left, right, |a, b| a.div_checked(b));
+ return try_binary(left, right, |a, b| a.div_checked(b));
}
/// Perform `left / right` operation on two arrays. If either left or right
value is null
@@ -1414,6 +856,7 @@ pub fn divide_checked<T: ArrowNumericType>(
///
/// For integer types overflow will wrap around.
///
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_opt<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
@@ -1433,17 +876,23 @@ pub fn divide_opt<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `divide_dyn_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef,
ArrowError> {
+ fn divide_op<T: ArrowPrimitiveType>(
+ left: &PrimitiveArray<T>,
+ right: &PrimitiveArray<T>,
+ ) -> Result<PrimitiveArray<T>, ArrowError> {
+ try_binary(left, right, |a, b| {
+ if b.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ Ok(a.div_wrapping(b))
+ }
+ })
+ }
+
downcast_primitive_array!(
- (left, right) => {
- math_checked_divide_op(left, right, |a, b| {
- if b.is_zero() {
- Err(ArrowError::DivideByZero)
- } else {
- Ok(a.div_wrapping(b))
- }
- }).map(|a| Arc::new(a) as ArrayRef)
- }
+ (left, right) => divide_op(left, right).map(|a| Arc::new(a) as
ArrayRef),
_ => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
left.data_type(), right.data_type()
@@ -1457,19 +906,12 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array)
-> Result<ArrayRef, Arrow
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `divide_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_dyn_checked(
left: &dyn Array,
right: &dyn Array,
) -> Result<ArrayRef, ArrowError> {
- downcast_primitive_array!(
- (left, right) => {
- math_checked_divide_op(left, right, |a, b|
a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef)
- }
- _ => Err(ArrowError::CastError(format!(
- "Unsupported data type {}, {}",
- left.data_type(), right.data_type()
- )))
- )
+ crate::numeric::div(&left, &right)
}
/// Perform `left / right` operation on two arrays. If either left or right
value is null
@@ -1481,6 +923,7 @@ pub fn divide_dyn_checked(
/// Unlike `divide_dyn` or `divide_dyn_checked`, division by zero will get a
null value instead of
/// returning an `Err`, this also doesn't check overflowing, overflowing will
just wrap
/// the result around.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_dyn_opt(
left: &dyn Array,
right: &dyn Array,
@@ -1513,18 +956,20 @@ pub fn divide_dyn_opt(
/// If either left or right value is null then the result is also null.
///
/// For an overflow-checking variant, use `divide_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide<T: ArrowNumericType>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
// TODO: This is incorrect as div_wrapping has side-effects for integer
types
// and so may panic on null values (#2647)
- math_op(left, right, |a, b| a.div_wrapping(b))
+ binary(left, right, |a, b| a.div_wrapping(b))
}
/// Modulus every value in an array by a scalar. If any value in the array is
null then the
/// result is also null. If the scalar is zero then the result of this
operation will be
/// `Err(ArrowError::DivideByZero)`.
+#[deprecated(note = "Use arrow_arith::numeric::rem")]
pub fn modulus_scalar<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
modulo: T::Native,
@@ -1539,6 +984,7 @@ pub fn modulus_scalar<T: ArrowNumericType>(
/// Modulus every value in an array by a scalar. If any value in the array is
null then the
/// result is also null. If the scalar is zero then the result of this
operation will be
/// `Err(ArrowError::DivideByZero)`.
+#[deprecated(note = "Use arrow_arith::numeric::rem")]
pub fn modulus_scalar_dyn<T: ArrowNumericType>(
array: &dyn Array,
modulo: T::Native,
@@ -1552,6 +998,7 @@ pub fn modulus_scalar_dyn<T: ArrowNumericType>(
/// Divide every value in an array by a scalar. If any value in the array is
null then the
/// result is also null. If the scalar is zero then the result of this
operation will be
/// `Err(ArrowError::DivideByZero)`.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_scalar<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
divisor: T::Native,
@@ -1569,6 +1016,7 @@ pub fn divide_scalar<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_scalar_dyn<T: ArrowNumericType>(
array: &dyn Array,
divisor: T::Native,
@@ -1586,6 +1034,7 @@ pub fn divide_scalar_dyn<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `divide_scalar_dyn` instead.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_scalar_checked_dyn<T: ArrowNumericType>(
array: &dyn Array,
divisor: T::Native,
@@ -1608,6 +1057,7 @@ pub fn divide_scalar_checked_dyn<T: ArrowNumericType>(
/// Unlike `divide_scalar_dyn` or `divide_scalar_checked_dyn`, division by
zero will get a
/// null value instead of returning an `Err`, this also doesn't check
overflowing, overflowing
/// will just wrap the result around.
+#[deprecated(note = "Use arrow_arith::numeric::div")]
pub fn divide_scalar_opt_dyn<T: ArrowNumericType>(
array: &dyn Array,
divisor: T::Native,
@@ -1625,11 +1075,13 @@ pub fn divide_scalar_opt_dyn<T: ArrowNumericType>(
}
#[cfg(test)]
+#[allow(deprecated)]
mod tests {
use super::*;
use arrow_array::builder::{
BooleanBufferBuilder, BufferBuilder, PrimitiveDictionaryBuilder,
};
+ use arrow_array::cast::AsArray;
use arrow_array::temporal_conversions::SECONDS_IN_DAY;
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::i256;
@@ -1678,16 +1130,14 @@ mod tests {
)]);
let b =
IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
- let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date32Type>().value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1,
2).unwrap())
);
let c = add_dyn(&b, &a).unwrap();
- let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date32Type>().value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1,
2).unwrap())
);
}
@@ -1702,16 +1152,14 @@ mod tests {
1, 2, 3,
)]);
let c = add_dyn(&a, &b).unwrap();
- let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date32Type>().value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2,
3).unwrap())
);
let c = add_dyn(&b, &a).unwrap();
- let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date32Type>().value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2,
3).unwrap())
);
}
@@ -1724,16 +1172,14 @@ mod tests {
let b =
IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3,
1).unwrap())
);
let c = add_dyn(&b, &a).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3,
1).unwrap())
);
}
@@ -1745,16 +1191,14 @@ mod tests {
)]);
let b =
IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1,
2).unwrap())
);
let c = add_dyn(&b, &a).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1,
2).unwrap())
);
}
@@ -1769,16 +1213,14 @@ mod tests {
1, 2, 3,
)]);
let c = add_dyn(&a, &b).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2,
3).unwrap())
);
let c = add_dyn(&b, &a).unwrap();
- let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
- c.value(0),
+ c.as_primitive::<Date64Type>().value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2,
3).unwrap())
);
}
@@ -2584,11 +2026,11 @@ mod tests {
}
#[test]
- #[should_panic(expected = "DivideByZero")]
fn test_f32_array_modulus_dyn_by_zero() {
let a = Float32Array::from(vec![1.5]);
let b = Float32Array::from(vec![0.0]);
- modulus_dyn(&a, &b).unwrap();
+ let result = modulus_dyn(&a, &b).unwrap();
+ assert!(result.as_primitive::<Float32Type>().value(0).is_nan());
}
#[test]
@@ -3838,10 +3280,6 @@ mod tests {
<TimestampSecondType as ArrowPrimitiveType>::Native::MIN,
]);
- // unchecked
- let result = subtract_dyn(&a, &b);
- assert!(!&result.is_err());
-
// checked
let result = subtract_dyn_checked(&a, &b);
assert!(&result.is_err());
@@ -3866,16 +3304,8 @@ mod tests {
#[test]
fn test_timestamp_microsecond_subtract_timestamp_overflow() {
- let a = TimestampMicrosecondArray::from(vec![
- <TimestampMicrosecondType as ArrowPrimitiveType>::Native::MAX,
- ]);
- let b = TimestampMicrosecondArray::from(vec![
- <TimestampMicrosecondType as ArrowPrimitiveType>::Native::MIN,
- ]);
-
- // unchecked
- let result = subtract_dyn(&a, &b);
- assert!(!&result.is_err());
+ let a = TimestampMicrosecondArray::from(vec![i64::MAX]);
+ let b = TimestampMicrosecondArray::from(vec![i64::MIN]);
// checked
let result = subtract_dyn_checked(&a, &b);
@@ -3901,16 +3331,8 @@ mod tests {
#[test]
fn test_timestamp_millisecond_subtract_timestamp_overflow() {
- let a = TimestampMillisecondArray::from(vec![
- <TimestampMillisecondType as ArrowPrimitiveType>::Native::MAX,
- ]);
- let b = TimestampMillisecondArray::from(vec![
- <TimestampMillisecondType as ArrowPrimitiveType>::Native::MIN,
- ]);
-
- // unchecked
- let result = subtract_dyn(&a, &b);
- assert!(!&result.is_err());
+ let a = TimestampMillisecondArray::from(vec![i64::MAX]);
+ let b = TimestampMillisecondArray::from(vec![i64::MIN]);
// checked
let result = subtract_dyn_checked(&a, &b);
@@ -3943,10 +3365,6 @@ mod tests {
<TimestampNanosecondType as ArrowPrimitiveType>::Native::MIN,
]);
- // unchecked
- let result = subtract_dyn(&a, &b);
- assert!(!&result.is_err());
-
// checked
let result = subtract_dyn_checked(&a, &b);
assert!(&result.is_err());
diff --git a/arrow-arith/src/lib.rs b/arrow-arith/src/lib.rs
index 60d31c972..2d5451e04 100644
--- a/arrow-arith/src/lib.rs
+++ b/arrow-arith/src/lib.rs
@@ -18,8 +18,10 @@
//! Arrow arithmetic and aggregation kernels
pub mod aggregate;
+#[doc(hidden)] // Kernels to be removed in a future release
pub mod arithmetic;
pub mod arity;
pub mod bitwise;
pub mod boolean;
+pub mod numeric;
pub mod temporal;
diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs
new file mode 100644
index 000000000..816fcaa94
--- /dev/null
+++ b/arrow-arith/src/numeric.rs
@@ -0,0 +1,672 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`]
+
+use std::cmp::Ordering;
+use std::sync::Arc;
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::*;
+use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
+use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
+
+use crate::arity::{binary, try_binary};
+
+/// Perform `lhs + rhs` by dispatching `Op::Add` to `arithmetic_op`, returning an error on overflow
+pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
+ arithmetic_op(Op::Add, lhs, rhs) // integers use `add_checked`; floats use `add_wrapping` (see `integer_op` / `float_op`)
+}
+
+/// Perform `lhs + rhs` via `Op::AddWrapping`, wrapping on overflow for [`DataType::is_integer`]
+pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef,
ArrowError> {
+ arithmetic_op(Op::AddWrapping, lhs, rhs) // note: `timestamp_op` handles this variant exactly like `Op::Add` (checked)
+}
+
+/// Perform `lhs - rhs` by dispatching `Op::Sub` to `arithmetic_op`, returning an error on overflow
+pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
+ arithmetic_op(Op::Sub, lhs, rhs) // integers use `sub_checked`; floats use `sub_wrapping` (see `integer_op` / `float_op`)
+}
+
+/// Perform `lhs - rhs` via `Op::SubWrapping`, wrapping on overflow for [`DataType::is_integer`]
+pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef,
ArrowError> {
+ arithmetic_op(Op::SubWrapping, lhs, rhs) // note: `timestamp_op` handles this variant exactly like `Op::Sub` (checked)
+}
+
+/// Perform `lhs * rhs` by dispatching `Op::Mul` to `arithmetic_op`, returning an error on overflow
+pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
+ arithmetic_op(Op::Mul, lhs, rhs) // integers use `mul_checked`; floats use `mul_wrapping` (see `integer_op` / `float_op`)
+}
+
+/// Perform `lhs * rhs` via `Op::MulWrapping`, wrapping on overflow for [`DataType::is_integer`]
+pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef,
ArrowError> {
+ arithmetic_op(Op::MulWrapping, lhs, rhs) // integers use `mul_wrapping` through the infallible `op!` macro
+}
+
+/// Perform `lhs / rhs` by dispatching `Op::Div` to `arithmetic_op`
+///
+/// Overflow or division by zero will result in an error, with the exception of
+/// floating point numbers, which instead follow the IEEE 754 rules
+pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
+ arithmetic_op(Op::Div, lhs, rhs) // integers use `div_checked`; floats use `div_wrapping` (see `integer_op` / `float_op`)
+}
+
+/// Perform `lhs % rhs` by dispatching `Op::Rem` to `arithmetic_op`
+///
+/// Overflow or division by zero will result in an error, with the exception of
+/// floating point numbers, which instead follow the IEEE 754 rules
+pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
+ arithmetic_op(Op::Rem, lhs, rhs) // integers use `mod_checked`; floats use `mod_wrapping` (see `integer_op` / `float_op`)
+}
+
+/// An enumeration of the arithmetic operations exposed by this module's public kernels
+///
+/// This allows sharing the type dispatch logic (`arithmetic_op`) across the various kernels
+#[derive(Debug, Copy, Clone)] // `Debug` output is embedded in error messages via `{op:?}`
+enum Op {
+ AddWrapping, // selected by `add_wrapping`
+ Add, // selected by `add`
+ SubWrapping, // selected by `sub_wrapping`
+ Sub, // selected by `sub`
+ MulWrapping, // selected by `mul_wrapping`
+ Mul, // selected by `mul`
+ Div, // selected by `div`
+ Rem, // selected by `rem`
+}
+
+impl Op { // helper queried by the fallback arm of `arithmetic_op`
+ fn commutative(&self) -> bool { // reports whether operands may be swapped without changing the result
+ matches!(self, Self::Add | Self::AddWrapping) // only the two addition variants qualify
+ }
+}
+
+/// Dispatch the given `op` to the appropriate specialized kernel, chosen from the data types of `lhs` and `rhs`
+fn arithmetic_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<ArrayRef, ArrowError> {
+ use DataType::*;
+ use IntervalUnit::*;
+ use TimeUnit::*;
+
+ macro_rules! integer_helper { // adapter so `downcast_integer!` can invoke `integer_op::<$t>` with the matched type
+ ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident,
$r_scalar:ident) => {
+ integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar)
+ };
+ }
+
+ let (l, l_scalar) = lhs.get(); // `l_scalar` is forwarded to the kernels as their `l_s` flag
+ let (r, r_scalar) = rhs.get(); // `r_scalar` is forwarded to the kernels as their `r_s` flag
+ downcast_integer! {
+ l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r,
r_scalar),
+ (Float16, Float16) => float_op::<Float16Type>(op, l, l_scalar, r,
r_scalar),
+ (Float32, Float32) => float_op::<Float32Type>(op, l, l_scalar, r,
r_scalar),
+ (Float64, Float64) => float_op::<Float64Type>(op, l, l_scalar, r,
r_scalar),
+ (Timestamp(Second, _), _) => timestamp_op::<TimestampSecondType>(op,
l, l_scalar, r, r_scalar), // any rhs type is forwarded; `timestamp_op` rejects unsupported combinations
+ (Timestamp(Millisecond, _), _) =>
timestamp_op::<TimestampMillisecondType>(op, l, l_scalar, r, r_scalar),
+ (Timestamp(Microsecond, _), _) =>
timestamp_op::<TimestampMicrosecondType>(op, l, l_scalar, r, r_scalar),
+ (Timestamp(Nanosecond, _), _) =>
timestamp_op::<TimestampNanosecondType>(op, l, l_scalar, r, r_scalar),
+ (Duration(Second), Duration(Second)) =>
duration_op::<DurationSecondType>(op, l, l_scalar, r, r_scalar), // durations must share the same unit on both sides
+ (Duration(Millisecond), Duration(Millisecond)) =>
duration_op::<DurationMillisecondType>(op, l, l_scalar, r, r_scalar),
+ (Duration(Microsecond), Duration(Microsecond)) =>
duration_op::<DurationMicrosecondType>(op, l, l_scalar, r, r_scalar),
+ (Duration(Nanosecond), Duration(Nanosecond)) =>
duration_op::<DurationNanosecondType>(op, l, l_scalar, r, r_scalar),
+ (Interval(YearMonth), Interval(YearMonth)) =>
interval_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar), // intervals must share the same unit on both sides
+ (Interval(DayTime), Interval(DayTime)) =>
interval_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
+ (Interval(MonthDayNano), Interval(MonthDayNano)) =>
interval_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+ (Date32, _) => date_op::<Date32Type>(op, l, l_scalar, r, r_scalar), // any rhs type is forwarded to `date_op`
+ (Date64, _) => date_op::<Date64Type>(op, l, l_scalar, r, r_scalar),
+ (Decimal128(_, _), Decimal128(_, _)) =>
decimal_op::<Decimal128Type>(op, l, l_scalar, r, r_scalar), // precision/scale are not matched here
+ (Decimal256(_, _), Decimal256(_, _)) =>
decimal_op::<Decimal256Type>(op, l, l_scalar, r, r_scalar),
+ (l_t, r_t) => match (l_t, r_t) { // fallback: no arm above matched this type pair
+ (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if
op.commutative() => {
+ arithmetic_op(op, rhs, lhs) // retry with operands swapped so the date/timestamp arm above handles it
+ }
+ _ => Err(ArrowError::InvalidArgumentError( // every other combination is rejected
+ format!("Invalid arithmetic operation: {l_t} {op:?} {r_t}")
+ ))
+ }
+ }
+}
+
+/// Perform an infallible binary operation on potentially scalar inputs; `$l_s` / `$r_s` flag which side is a scalar
+macro_rules! op {
+ ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
+ match ($l_s, $r_s) {
+ (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, // same kind on both sides: `$l`/`$r` are rebound to element values inside the closure
+ (true, false) => match ($l.null_count() == 0).then(|| $l.value(0))
{ // only lhs is scalar: read its first value unless it is null
+ None => PrimitiveArray::new_null($r.len()), // null scalar => all-null result with the length of rhs
+ Some($l) => $r.unary(|$r| $op), // apply the scalar against every rhs element
+ },
+ (false, true) => match ($r.null_count() == 0).then(|| $r.value(0))
{ // only rhs is scalar: mirror image of the arm above
+ None => PrimitiveArray::new_null($l.len()), // null scalar => all-null result with the length of lhs
+ Some($r) => $l.unary(|$l| $op), // apply the scalar against every lhs element
+ },
+ }
+ };
+}
+
+/// Same as `op` but with a type hint (`$t`) for the returned array, which is wrapped in an `Arc`
+macro_rules! op_ref {
+ ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{
+ let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); // the annotation fixes the output primitive type
+ Arc::new(array)
+ }};
+}
+
+/// Perform a fallible binary operation on potentially scalar inputs; errors are propagated with `?` to the calling function
+macro_rules! try_op {
+ ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
+ match ($l_s, $r_s) {
+ (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, // same kind on both sides: `$l`/`$r` are rebound to element values inside the closure
+ (true, false) => match ($l.null_count() == 0).then(|| $l.value(0))
{ // only lhs is scalar: read its first value unless it is null
+ None => PrimitiveArray::new_null($r.len()), // null scalar => all-null result with the length of rhs
+ Some($l) => $r.try_unary(|$r| $op)?, // apply the scalar against every rhs element
+ },
+ (false, true) => match ($r.null_count() == 0).then(|| $r.value(0))
{ // only rhs is scalar: mirror image of the arm above
+ None => PrimitiveArray::new_null($l.len()), // null scalar => all-null result with the length of lhs
+ Some($r) => $l.try_unary(|$l| $op)?, // apply the scalar against every lhs element
+ },
+ }
+ };
+}
+
+/// Same as `try_op` but with a type hint (`$t`) for the returned array, which is wrapped in an `Arc`
+macro_rules! try_op_ref {
+ ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{
+ let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); // the annotation fixes the output primitive type
+ Arc::new(array)
+ }};
+}
+
+/// Perform an arithmetic operation on integers: wrapping variants use the infallible `op!`, all others the fallible `try_op!`
+fn integer_op<T: ArrowPrimitiveType>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool, // forwarded to `op!` / `try_op!` as the lhs scalar flag
+ r: &dyn Array,
+ r_s: bool, // forwarded to `op!` / `try_op!` as the rhs scalar flag
+) -> Result<ArrayRef, ArrowError> {
+ let l = l.as_primitive::<T>(); // both sides are viewed as the same primitive type `T`
+ let r = r.as_primitive::<T>();
+ let array: PrimitiveArray<T> = match op { // inside the macros `l` / `r` are rebound to individual element values
+ Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)),
+ Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)),
+ Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
+ Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)),
+ Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
+ Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)),
+ Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), // division has no wrapping variant in `Op`
+ Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), // remainder has no wrapping variant in `Op`
+ };
+ Ok(Arc::new(array))
+}
+
+/// Perform an arithmetic operation on floats; every variant goes through the infallible `op!` macro
+fn float_op<T: ArrowPrimitiveType>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool, // forwarded to `op!` as the lhs scalar flag
+ r: &dyn Array,
+ r_s: bool, // forwarded to `op!` as the rhs scalar flag
+) -> Result<ArrayRef, ArrowError> {
+ let l = l.as_primitive::<T>(); // both sides are viewed as the same primitive type `T`
+ let r = r.as_primitive::<T>();
+ let array: PrimitiveArray<T> = match op { // checked and wrapping variants share one implementation for floats
+ Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)),
+ Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
+ Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
+ Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)), // no zero-divisor check here, unlike `integer_op`'s `div_checked`
+ Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)), // no zero-divisor check here, unlike `integer_op`'s `mod_checked`
+ };
+ Ok(Arc::new(array))
+}
+
+/// Arithmetic trait for timestamp arrays: fallible interval addition/subtraction on raw `i64` timestamp values
+trait TimestampOp: ArrowTimestampType {
+ type Duration: ArrowPrimitiveType<Native = i64>; // result type of timestamp - timestamp, and rhs type of timestamp +/- duration in `timestamp_op`
+
+ fn add_year_month(timestamp: i64, delta: i32) -> Result<i64, ArrowError>; // `delta` is an `IntervalYearMonthType` native value
+ fn add_day_time(timestamp: i64, delta: i64) -> Result<i64, ArrowError>; // `delta` is an `IntervalDayTimeType` native value
+ fn add_month_day_nano(timestamp: i64, delta: i128) -> Result<i64,
ArrowError>; // `delta` is an `IntervalMonthDayNanoType` native value
+
+ fn sub_year_month(timestamp: i64, delta: i32) -> Result<i64, ArrowError>; // subtraction counterparts of the three functions above
+ fn sub_day_time(timestamp: i64, delta: i64) -> Result<i64, ArrowError>;
+ fn sub_month_day_nano(timestamp: i64, delta: i128) -> Result<i64,
ArrowError>;
+}
+
+macro_rules! timestamp { // implements `TimestampOp` for timestamp type `$t`, with `$d` as its associated `Duration`
+ ($t:ty, $d:ty) => {
+ impl TimestampOp for $t {
+ type Duration = $d;
+
+ fn add_year_month(left: i64, right: i32) -> Result<i64,
ArrowError> {
+ Self::add_year_months(left, right) // forwards to the differently named `add_year_months` on `$t`
+ }
+
+ fn add_day_time(left: i64, right: i64) -> Result<i64, ArrowError> {
+ Self::add_day_time(left, right) // NOTE(review): presumably resolves to an inherent function on `$t` rather than recursing - confirm
+ }
+
+ fn add_month_day_nano(left: i64, right: i128) -> Result<i64,
ArrowError> {
+ Self::add_month_day_nano(left, right) // NOTE(review): same name as the trait function; presumably the inherent one - confirm
+ }
+
+ fn sub_year_month(left: i64, right: i32) -> Result<i64,
ArrowError> {
+ Self::subtract_year_months(left, right) // trait uses `sub_*`, the callee is named `subtract_*`
+ }
+
+ fn sub_day_time(left: i64, right: i64) -> Result<i64, ArrowError> {
+ Self::subtract_day_time(left, right)
+ }
+
+ fn sub_month_day_nano(left: i64, right: i128) -> Result<i64,
ArrowError> {
+ Self::subtract_month_day_nano(left, right)
+ }
+ }
+ };
+}
+timestamp!(TimestampSecondType, DurationSecondType); // each timestamp type is paired with the duration type of the same unit
+timestamp!(TimestampMillisecondType, DurationMillisecondType);
+timestamp!(TimestampMicrosecondType, DurationMicrosecondType);
+timestamp!(TimestampNanosecondType, DurationNanosecondType);
+
+/// Perform arithmetic operation on a timestamp array (`l`); the kernel is chosen from `op` and the data type of `r`
+fn timestamp_op<T: TimestampOp>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool, // forwarded to `try_op!` as the lhs scalar flag
+ r: &dyn Array,
+ r_s: bool, // forwarded to `try_op!` as the rhs scalar flag
+) -> Result<ArrayRef, ArrowError> {
+ use DataType::*;
+ use IntervalUnit::*;
+
+ // Note: interval arithmetic should account for timezones (#4457)
+ let l = l.as_primitive::<T>();
+ let array: PrimitiveArray<T> = match (op, r.data_type()) { // wrapping variants are treated like their checked counterparts in every arm
+ (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT =>
{ // timestamp - timestamp of the same unit
+ let r = r.as_primitive::<T>();
+ return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s,
l.sub_checked(r))); // early return: the result is a `T::Duration` array, so no timezone is attached
+ }
+
+ (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => { // timestamp + duration of the same unit
+ let r = r.as_primitive::<T::Duration>();
+ try_op!(l, l_s, r, r_s, l.add_checked(r))
+ }
+ (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => { // timestamp - duration of the same unit
+ let r = r.as_primitive::<T::Duration>();
+ try_op!(l, l_s, r, r_s, l.sub_checked(r))
+ }
+
+ (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { // interval arms delegate to the `TimestampOp` functions of `T`
+ let r = r.as_primitive::<IntervalYearMonthType>();
+ try_op!(l, l_s, r, r_s, T::add_year_month(l, r))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
+ let r = r.as_primitive::<IntervalYearMonthType>();
+ try_op!(l, l_s, r, r_s, T::sub_year_month(l, r))
+ }
+
+ (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
+ let r = r.as_primitive::<IntervalDayTimeType>();
+ try_op!(l, l_s, r, r_s, T::add_day_time(l, r))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
+ let r = r.as_primitive::<IntervalDayTimeType>();
+ try_op!(l, l_s, r, r_s, T::sub_day_time(l, r))
+ }
+
+ (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
+ let r = r.as_primitive::<IntervalMonthDayNanoType>();
+ try_op!(l, l_s, r, r_s, T::add_month_day_nano(l, r))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
+ let r = r.as_primitive::<IntervalMonthDayNanoType>();
+ try_op!(l, l_s, r, r_s, T::sub_month_day_nano(l, r))
+ }
+ _ => { // multiplication, division, remainder, mismatched units and all other rhs types end up here
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid timestamp arithmetic operation: {} {op:?} {}",
+ l.data_type(),
+ r.data_type()
+ )))
+ }
+ };
+ Ok(Arc::new(array.with_timezone_opt(l.timezone()))) // carry the lhs timezone over to the resulting timestamp array
+}
+
+/// Interval arithmetic on the native representation of a date type
+///
+/// Each method takes a single native date value plus the native encoding of
+/// an interval (`i32` for year-month, `i64` for day-time, `i128` for
+/// month-day-nano) and returns the shifted native date value.
+///
+/// Note: these should be fallible (#4456)
+trait DateOp: ArrowTemporalType {
+    fn add_year_month(date: Self::Native, delta: i32) -> Self::Native;
+    fn add_day_time(date: Self::Native, delta: i64) -> Self::Native;
+    fn add_month_day_nano(date: Self::Native, delta: i128) -> Self::Native;
+
+    fn sub_year_month(date: Self::Native, delta: i32) -> Self::Native;
+    fn sub_day_time(date: Self::Native, delta: i64) -> Self::Native;
+    fn sub_month_day_nano(date: Self::Native, delta: i128) -> Self::Native;
+}
+
+/// Implements [`DateOp`] for a date type by forwarding every method to the
+/// inherent helper on that type (`add_year_months`, `add_day_time`,
+/// `add_month_day_nano` and their `subtract_*` counterparts).
+macro_rules! date {
+    ($t:ty) => {
+        impl DateOp for $t {
+            fn add_year_month(date: Self::Native, delta: i32) -> Self::Native {
+                Self::add_year_months(date, delta)
+            }
+
+            fn add_day_time(date: Self::Native, delta: i64) -> Self::Native {
+                Self::add_day_time(date, delta)
+            }
+
+            fn add_month_day_nano(date: Self::Native, delta: i128) -> Self::Native {
+                Self::add_month_day_nano(date, delta)
+            }
+
+            fn sub_year_month(date: Self::Native, delta: i32) -> Self::Native {
+                Self::subtract_year_months(date, delta)
+            }
+
+            fn sub_day_time(date: Self::Native, delta: i64) -> Self::Native {
+                Self::subtract_day_time(date, delta)
+            }
+
+            fn sub_month_day_nano(date: Self::Native, delta: i128) -> Self::Native {
+                Self::subtract_month_day_nano(date, delta)
+            }
+        }
+    };
+}
+date!(Date32Type);
+date!(Date64Type);
+
+/// Fallible addition and subtraction of two native interval values
+///
+/// Both operands and the result use the native encoding of the implementing
+/// interval type; failures are reported as [`ArrowError`].
+trait IntervalOp: ArrowPrimitiveType {
+    fn add(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError>;
+    fn sub(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError>;
+}
+
+/// Year-month intervals are a single native value, so checked arithmetic on
+/// that value is sufficient.
+impl IntervalOp for IntervalYearMonthType {
+    fn add(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        lhs.add_checked(rhs)
+    }
+
+    fn sub(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        lhs.sub_checked(rhs)
+    }
+}
+
+/// Day-time intervals are split into their (days, milliseconds) parts, each
+/// part is combined with checked arithmetic, and the result is re-packed.
+impl IntervalOp for IntervalDayTimeType {
+    fn add(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        let (lhs_days, lhs_millis) = Self::to_parts(lhs);
+        let (rhs_days, rhs_millis) = Self::to_parts(rhs);
+        // Arguments are evaluated left to right: days are checked before millis
+        Ok(Self::make_value(
+            lhs_days.add_checked(rhs_days)?,
+            lhs_millis.add_checked(rhs_millis)?,
+        ))
+    }
+
+    fn sub(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        let (lhs_days, lhs_millis) = Self::to_parts(lhs);
+        let (rhs_days, rhs_millis) = Self::to_parts(rhs);
+        Ok(Self::make_value(
+            lhs_days.sub_checked(rhs_days)?,
+            lhs_millis.sub_checked(rhs_millis)?,
+        ))
+    }
+}
+
+/// Month-day-nano intervals are split into their (months, days, nanoseconds)
+/// parts, each part is combined with checked arithmetic, and the result is
+/// re-packed.
+impl IntervalOp for IntervalMonthDayNanoType {
+    fn add(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        let (lhs_months, lhs_days, lhs_nanos) = Self::to_parts(lhs);
+        let (rhs_months, rhs_days, rhs_nanos) = Self::to_parts(rhs);
+        // Arguments are evaluated left to right: months, then days, then nanos
+        Ok(Self::make_value(
+            lhs_months.add_checked(rhs_months)?,
+            lhs_days.add_checked(rhs_days)?,
+            lhs_nanos.add_checked(rhs_nanos)?,
+        ))
+    }
+
+    fn sub(lhs: Self::Native, rhs: Self::Native) -> Result<Self::Native, ArrowError> {
+        let (lhs_months, lhs_days, lhs_nanos) = Self::to_parts(lhs);
+        let (rhs_months, rhs_days, rhs_nanos) = Self::to_parts(rhs);
+        Ok(Self::make_value(
+            lhs_months.sub_checked(rhs_months)?,
+            lhs_days.sub_checked(rhs_days)?,
+            lhs_nanos.sub_checked(rhs_nanos)?,
+        ))
+    }
+}
+
+/// Perform arithmetic operation on an interval array
+///
+/// Both operands are downcast to the same interval type `T`. Only addition
+/// and subtraction are supported; their `*Wrapping` variants are routed to
+/// the same fallible [`IntervalOp`] methods. Every other `op` returns
+/// [`ArrowError::InvalidArgumentError`].
+///
+/// `l_s` and `r_s` are forwarded unchanged to `try_op_ref!`; presumably they
+/// flag a scalar operand -- confirm against the macro definition.
+fn interval_op<T: IntervalOp>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool,
+ r: &dyn Array,
+ r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+ let l = l.as_primitive::<T>();
+ let r = r.as_primitive::<T>();
+ match op {
+ Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s,
T::add(l, r))),
+ Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s,
T::sub(l, r))),
+ _ => Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid interval arithmetic operation: {} {op:?} {}",
+ l.data_type(),
+ r.data_type()
+ ))),
+ }
+}
+
+/// Perform arithmetic operation on a duration array
+///
+/// Both operands are downcast to the same primitive type `T`. Only addition
+/// and subtraction are supported, and both the plain and the `*Wrapping`
+/// variants of `op` use `add_checked` / `sub_checked`. Every other `op`
+/// returns [`ArrowError::InvalidArgumentError`].
+///
+/// `l_s` and `r_s` are forwarded unchanged to `try_op_ref!`; presumably they
+/// flag a scalar operand -- confirm against the macro definition.
+fn duration_op<T: ArrowPrimitiveType>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool,
+ r: &dyn Array,
+ r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+ let l = l.as_primitive::<T>();
+ let r = r.as_primitive::<T>();
+ match op {
+ Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s,
l.add_checked(r))),
+ Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s,
l.sub_checked(r))),
+ _ => Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid duration arithmetic operation: {} {op:?} {}",
+ l.data_type(),
+ r.data_type()
+ ))),
+ }
+}
+
+/// Perform arithmetic operation on a date array
+///
+/// The left-hand side is downcast to a date array of type `T`; the right-hand
+/// side must be one of the three interval types, and only addition and
+/// subtraction (plain or `*Wrapping`) are supported. Every other combination
+/// returns [`ArrowError::InvalidArgumentError`].
+///
+/// Unlike the timestamp kernel this uses `op_ref!` rather than `try_op!`: the
+/// `DateOp` helpers it calls return a plain native value, not a `Result`.
+///
+/// `l_s` and `r_s` are forwarded unchanged to `op_ref!`; presumably they flag
+/// a scalar operand -- confirm against the macro definition.
+fn date_op<T: DateOp>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool,
+ r: &dyn Array,
+ r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+ use DataType::*;
+ use IntervalUnit::*;
+
+ // Note: interval arithmetic should account for timezones (#4457)
+ let l = l.as_primitive::<T>();
+ match (op, r.data_type()) {
+ // date +/- year-month interval
+ (Op::Add | Op::AddWrapping, Interval(YearMonth)) => {
+ let r = r.as_primitive::<IntervalYearMonthType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r)))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
+ let r = r.as_primitive::<IntervalYearMonthType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r)))
+ }
+
+ // date +/- day-time interval
+ (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
+ let r = r.as_primitive::<IntervalDayTimeType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r)))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
+ let r = r.as_primitive::<IntervalDayTimeType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r)))
+ }
+
+ // date +/- month-day-nano interval
+ (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
+ let r = r.as_primitive::<IntervalMonthDayNanoType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r)))
+ }
+ (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
+ let r = r.as_primitive::<IntervalMonthDayNanoType>();
+ Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r)))
+ }
+
+ _ => Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid date arithmetic operation: {} {op:?} {}",
+ l.data_type(),
+ r.data_type()
+ ))),
+ }
+}
+
+/// Perform arithmetic operation on decimal arrays
+///
+/// Both operands are downcast to the same decimal type `T`; their precisions
+/// and scales may differ. The output precision and scale are derived per
+/// operation:
+///
+/// * add / sub: scale `max(s1, s2)`,
+///   precision `max(s1, s2) + max(p1 - s1, p2 - s2) + 1`
+/// * mul: scale `s1 + s2`, precision `p1 + p2 + 1`
+/// * div: scale `s1 + 4`, precision `p1 - s1 + s2 + scale`
+/// * rem: scale `max(s1, s2)`,
+///   precision `min(p1 - s1, p2 - s2) + max(s1, s2)`
+///
+/// Precision is always clamped to `T::MAX_PRECISION`; the division scale is
+/// clamped to `T::MAX_SCALE`.
+///
+/// # Errors
+///
+/// * multiplication whose output scale would exceed `T::MAX_SCALE`
+/// * a power-of-ten rescale factor that does not fit in `T::Native`
+/// * any error raised by the checked element-wise arithmetic or by
+///   `with_precision_and_scale`
+///
+/// `l_s` and `r_s` are forwarded unchanged to `try_op!`; presumably they flag
+/// a scalar operand -- confirm against the macro definition.
+fn decimal_op<T: DecimalType>(
+    op: Op,
+    l: &dyn Array,
+    l_s: bool,
+    r: &dyn Array,
+    r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+    let l = l.as_primitive::<T>();
+    let r = r.as_primitive::<T>();
+
+    let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) {
+        (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2),
+        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2),
+        _ => unreachable!(),
+    };
+
+    // Follow the Hive decimal arithmetic rules
+    // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+    let array: PrimitiveArray<T> = match op {
+        Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => {
+            // max(s1, s2)
+            let result_scale = *s1.max(s2);
+
+            // max(s1, s2) + max(p1-s1, p2-s2) + 1
+            let result_precision =
+                (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8)
+                    .saturating_add(1)
+                    .min(T::MAX_PRECISION);
+
+            // Rescale both sides to the common scale. The factors are computed
+            // with checked exponentiation: a wrapped factor would silently
+            // corrupt every value, so report the overflow instead.
+            let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?;
+            let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?;
+
+            match op {
+                Op::Add | Op::AddWrapping => {
+                    try_op!(
+                        l,
+                        l_s,
+                        r,
+                        r_s,
+                        l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?)
+                    )
+                }
+                Op::Sub | Op::SubWrapping => {
+                    try_op!(
+                        l,
+                        l_s,
+                        r,
+                        r_s,
+                        l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?)
+                    )
+                }
+                _ => unreachable!(),
+            }
+            .with_precision_and_scale(result_precision, result_scale)?
+        }
+        Op::Mul | Op::MulWrapping => {
+            // p1 + p2 + 1, saturating at every step so the sum cannot overflow u8
+            let result_precision = p1
+                .saturating_add(*p2)
+                .saturating_add(1)
+                .min(T::MAX_PRECISION);
+            let result_scale = s1.saturating_add(*s2);
+            if result_scale > T::MAX_SCALE {
+                // SQL standard says that if the resulting scale of a multiply operation goes
+                // beyond the maximum, rounding is not acceptable and thus an error occurs
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Output scale of {} {op:?} {} would exceed max scale of {}",
+                    l.data_type(),
+                    r.data_type(),
+                    T::MAX_SCALE
+                )));
+            }
+
+            try_op!(l, l_s, r, r_s, l.mul_checked(r))
+                .with_precision_and_scale(result_precision, result_scale)?
+        }
+
+        Op::Div => {
+            // Follow postgres and MySQL adding a fixed scale increment of 4
+            // s1 + 4
+            let result_scale = s1.saturating_add(4).min(T::MAX_SCALE);
+            let mul_pow = result_scale - s1 + s2;
+
+            // p1 - s1 + s2 + result_scale
+            let result_precision =
+                (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION);
+
+            // A positive exponent scales up the dividend, a negative one the
+            // divisor; checked exponentiation reports an unrepresentable factor
+            let (l_mul, r_mul) = match mul_pow.cmp(&0) {
+                Ordering::Greater => (
+                    T::Native::usize_as(10).pow_checked(mul_pow as _)?,
+                    T::Native::ONE,
+                ),
+                Ordering::Equal => (T::Native::ONE, T::Native::ONE),
+                Ordering::Less => (
+                    T::Native::ONE,
+                    T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?,
+                ),
+            };
+
+            try_op!(
+                l,
+                l_s,
+                r,
+                r_s,
+                l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?)
+            )
+            .with_precision_and_scale(result_precision, result_scale)?
+        }
+
+        Op::Rem => {
+            // max(s1, s2)
+            let result_scale = *s1.max(s2);
+            // min(p1-s1, p2 -s2) + max( s1,s2 )
+            let result_precision =
+                (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8)
+                    .min(T::MAX_PRECISION);
+
+            let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?;
+            let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?;
+
+            try_op!(
+                l,
+                l_s,
+                r,
+                r_s,
+                l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?)
+            )
+            .with_precision_and_scale(result_precision, result_scale)?
+        }
+    };
+
+    Ok(Arc::new(array))
+}
diff --git a/arrow-array/src/array/primitive_array.rs
b/arrow-array/src/array/primitive_array.rs
index 576f645b0..833732637 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -517,6 +517,15 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
Self::try_new(values, nulls).unwrap()
}
+    /// Build a [`PrimitiveArray`] with `length` slots, every one of them null
+    ///
+    /// The values buffer is filled with the native zero value and paired with
+    /// an all-null [`NullBuffer`] of the same length.
+    pub fn new_null(length: usize) -> Self {
+        let zeroed = vec![T::Native::usize_as(0); length];
+        let all_null = NullBuffer::new_null(length);
+        Self {
+            data_type: T::DATA_TYPE,
+            values: zeroed.into(),
+            nulls: Some(all_null),
+        }
+    }
+
/// Create a new [`PrimitiveArray`] from the provided values and nulls
///
/// # Errors
diff --git a/arrow-array/src/scalar.rs b/arrow-array/src/scalar.rs
index e54a999f9..c142107c5 100644
--- a/arrow-array/src/scalar.rs
+++ b/arrow-array/src/scalar.rs
@@ -92,6 +92,12 @@ impl Datum for dyn Array {
}
}
+/// A borrowed trait object is returned as-is, paired with `false`
+impl Datum for &dyn Array {
+    fn get(&self) -> (&dyn Array, bool) {
+        let array: &dyn Array = *self;
+        (array, false)
+    }
+}
+
/// A wrapper around a single value [`Array`] indicating kernels should treat
it as a scalar value
///
/// See [`Datum`] for more information
diff --git a/arrow/benches/arithmetic_kernels.rs
b/arrow/benches/arithmetic_kernels.rs
index 4ed197783..e982b0eb4 100644
--- a/arrow/benches/arithmetic_kernels.rs
+++ b/arrow/benches/arithmetic_kernels.rs
@@ -15,65 +15,61 @@
// specific language governing permissions and limitations
// under the License.
-#[macro_use]
-extern crate criterion;
-use criterion::Criterion;
-use rand::Rng;
+use criterion::*;
extern crate arrow;
+use arrow::compute::kernels::numeric::*;
use arrow::datatypes::Float32Type;
use arrow::util::bench_util::*;
-use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng};
+use arrow_array::Scalar;
fn add_benchmark(c: &mut Criterion) {
const BATCH_SIZE: usize = 64 * 1024;
for null_density in [0., 0.1, 0.5, 0.9, 1.0] {
let arr_a = create_primitive_array::<Float32Type>(BATCH_SIZE,
null_density);
let arr_b = create_primitive_array::<Float32Type>(BATCH_SIZE,
null_density);
- let scalar = seedable_rng().gen();
+ let scalar_a = create_primitive_array::<Float32Type>(1, 0.);
+ let scalar = Scalar::new(&scalar_a);
c.bench_function(&format!("add({null_density})"), |b| {
- b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap()))
+ b.iter(|| criterion::black_box(add_wrapping(&arr_a,
&arr_b).unwrap()))
});
c.bench_function(&format!("add_checked({null_density})"), |b| {
- b.iter(|| criterion::black_box(add_checked(&arr_a,
&arr_b).unwrap()))
+ b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("add_scalar({null_density})"), |b| {
- b.iter(|| criterion::black_box(add_scalar(&arr_a,
scalar).unwrap()))
+ b.iter(|| criterion::black_box(add_wrapping(&arr_a,
&scalar).unwrap()))
});
c.bench_function(&format!("subtract({null_density})"), |b| {
- b.iter(|| criterion::black_box(subtract(&arr_a, &arr_b).unwrap()))
+ b.iter(|| criterion::black_box(sub_wrapping(&arr_a,
&arr_b).unwrap()))
});
c.bench_function(&format!("subtract_checked({null_density})"), |b| {
- b.iter(|| criterion::black_box(subtract_checked(&arr_a,
&arr_b).unwrap()))
+ b.iter(|| criterion::black_box(sub(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("subtract_scalar({null_density})"), |b| {
- b.iter(|| criterion::black_box(subtract_scalar(&arr_a,
scalar).unwrap()))
+ b.iter(|| criterion::black_box(sub_wrapping(&arr_a,
&scalar).unwrap()))
});
c.bench_function(&format!("multiply({null_density})"), |b| {
- b.iter(|| criterion::black_box(multiply(&arr_a, &arr_b).unwrap()))
+ b.iter(|| criterion::black_box(mul_wrapping(&arr_a,
&arr_b).unwrap()))
});
c.bench_function(&format!("multiply_checked({null_density})"), |b| {
- b.iter(|| criterion::black_box(multiply_checked(&arr_a,
&arr_b).unwrap()))
+ b.iter(|| criterion::black_box(mul(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("multiply_scalar({null_density})"), |b| {
- b.iter(|| criterion::black_box(multiply_scalar(&arr_a,
scalar).unwrap()))
+ b.iter(|| criterion::black_box(mul_wrapping(&arr_a,
&scalar).unwrap()))
});
c.bench_function(&format!("divide({null_density})"), |b| {
- b.iter(|| criterion::black_box(divide(&arr_a, &arr_b).unwrap()))
- });
- c.bench_function(&format!("divide_checked({null_density})"), |b| {
- b.iter(|| criterion::black_box(divide_checked(&arr_a,
&arr_b).unwrap()))
+ b.iter(|| criterion::black_box(div(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("divide_scalar({null_density})"), |b| {
- b.iter(|| criterion::black_box(divide_scalar(&arr_a,
scalar).unwrap()))
+ b.iter(|| criterion::black_box(div(&arr_a, &scalar).unwrap()))
});
c.bench_function(&format!("modulo({null_density})"), |b| {
- b.iter(|| criterion::black_box(modulus(&arr_a, &arr_b).unwrap()))
+ b.iter(|| criterion::black_box(rem(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("modulo_scalar({null_density})"), |b| {
- b.iter(|| criterion::black_box(modulus_scalar(&arr_a,
scalar).unwrap()))
+ b.iter(|| criterion::black_box(rem(&arr_a, &scalar).unwrap()))
});
}
}
diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels/mod.rs
index d9c948c60..49eae6d3a 100644
--- a/arrow/src/compute/kernels/mod.rs
+++ b/arrow/src/compute/kernels/mod.rs
@@ -19,7 +19,9 @@
pub mod limit;
-pub use arrow_arith::{aggregate, arithmetic, arity, bitwise, boolean,
temporal};
+pub use arrow_arith::{
+ aggregate, arithmetic, arity, bitwise, boolean, numeric, temporal,
+};
pub use arrow_cast::cast;
pub use arrow_cast::parse as cast_utils;
pub use arrow_ord::{partition, sort};
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index 12aa1309c..a392d1dee 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -105,6 +105,8 @@ To export an array, create an `ArrowArray` using
[ArrowArray::try_new].
use std::{mem::size_of, ptr::NonNull, sync::Arc};
+pub use arrow_data::ffi::FFI_ArrowArray;
+pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags};
use arrow_schema::UnionMode;
use crate::array::{layout, ArrayData};
@@ -113,9 +115,6 @@ use crate::datatypes::DataType;
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
-pub use arrow_data::ffi::FFI_ArrowArray;
-pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags};
-
// returns the number of bits that buffer `i` (in the C data interface) is
expected to have.
// This is set by the Arrow specification
fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
@@ -412,7 +411,16 @@ impl<'a> ArrowArray<'a> {
#[cfg(test)]
mod tests {
- use super::*;
+ use std::collections::HashMap;
+ use std::convert::TryFrom;
+ use std::mem::ManuallyDrop;
+ use std::ptr::addr_of_mut;
+
+ use arrow_array::builder::UnionBuilder;
+ use arrow_array::cast::AsArray;
+ use arrow_array::types::{Float64Type, Int32Type};
+ use arrow_array::{StructArray, UnionArray};
+
use crate::array::{
make_array, Array, ArrayData, BooleanArray, Decimal128Array,
DictionaryArray,
DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray,
@@ -421,14 +429,8 @@ mod tests {
};
use crate::compute::kernels;
use crate::datatypes::{Field, Int8Type};
- use arrow_array::builder::UnionBuilder;
- use arrow_array::cast::AsArray;
- use arrow_array::types::{Float64Type, Int32Type};
- use arrow_array::{StructArray, UnionArray};
- use std::collections::HashMap;
- use std::convert::TryFrom;
- use std::mem::ManuallyDrop;
- use std::ptr::addr_of_mut;
+
+ use super::*;
#[test]
fn test_round_trip() {
@@ -440,10 +442,10 @@ mod tests {
// (simulate consumer) import it
let array = Int32Array::from(from_ffi(array, &schema).unwrap());
- let array = kernels::arithmetic::add(&array, &array).unwrap();
+ let array = kernels::numeric::add(&array, &array).unwrap();
// verify
- assert_eq!(array, Int32Array::from(vec![2, 4, 6]));
+ assert_eq!(array.as_ref(), &Int32Array::from(vec![2, 4, 6]));
}
#[test]
@@ -491,10 +493,10 @@ mod tests {
let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(array, &Int32Array::from(vec![Some(2), None]));
- let array = kernels::arithmetic::add(array, array).unwrap();
+ let array = kernels::numeric::add(array, array).unwrap();
// verify
- assert_eq!(array, Int32Array::from(vec![Some(4), None]));
+ assert_eq!(array.as_ref(), &Int32Array::from(vec![Some(4), None]));
// (drop/release)
Ok(())