This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new da3879ef4 Add dictionary support to subtract_scalar, multiply_scalar, divide_scalar (#2020)
da3879ef4 is described below
commit da3879ef407eb5818a6ef537e14fac56fdd8a69b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Jul 7 00:29:16 2022 -0700
Add dictionary support to subtract_scalar, multiply_scalar, divide_scalar (#2020)
---
arrow/src/compute/kernels/arithmetic.rs | 171 ++++++++++++++++++++++++++++++++
1 file changed, 171 insertions(+)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index f64038c19..9b860feee 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -760,6 +760,21 @@ where
Ok(unary(array, |value| value - scalar))
}
+/// Subtract every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
+where
+ T: datatypes::ArrowNumericType,
+ T::Native: Add<Output = T::Native>
+ + Sub<Output = T::Native>
+ + Mul<Output = T::Native>
+ + Div<Output = T::Native>
+ + Zero,
+{
+ unary_dyn::<_, T>(array, |value| value - scalar)
+}
+
/// Perform `-` operation on an array. If value is null then the result is
also null.
pub fn negate<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
where
@@ -824,6 +839,23 @@ where
Ok(unary(array, |value| value * scalar))
}
+/// Multiply every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
+where
+ T: datatypes::ArrowNumericType,
+ T::Native: Add<Output = T::Native>
+ + Sub<Output = T::Native>
+ + Mul<Output = T::Native>
+ + Div<Output = T::Native>
+ + Rem<Output = T::Native>
+ + Zero
+ + One,
+{
+ unary_dyn::<_, T>(array, |value| value * scalar)
+}
+
/// Perform `left % right` operation on two arrays. If either left or right
value is null
/// then the result is also null. If any right hand value is zero then the
result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
@@ -909,6 +941,21 @@ where
Ok(unary(array, |a| a / divisor))
}
+/// Divide every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. If the scalar is zero then the result of this
operation will be
+/// `Err(ArrowError::DivideByZero)`. The given array must be a
`PrimitiveArray` of the type
+/// same as the scalar, or a `DictionaryArray` of the value type same as the
scalar.
+pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) ->
Result<ArrayRef>
+where
+ T: datatypes::ArrowNumericType,
+ T::Native: Div<Output = T::Native> + Zero,
+{
+ if divisor.is_zero() {
+ return Err(ArrowError::DivideByZero);
+ }
+ unary_dyn::<_, T>(array, |value| value / divisor)
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1054,6 +1101,46 @@ mod tests {
assert_eq!(10, c.value(4));
}
#[test]
fn test_primitive_array_subtract_scalar_dyn() {
    // Plain primitive input: nulls propagate, every other slot is decremented by 1.
    let input = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
    let output = subtract_scalar_dyn::<Int32Type>(&input, 1_i32).unwrap();
    let output = output.as_any().downcast_ref::<Int32Array>().unwrap();
    let expected_plain = [Some(4), Some(5), Some(6), None, Some(8)];
    for (idx, want) in expected_plain.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, output.value(idx)),
            None => assert!(output.is_null(idx)),
        }
    }

    // Dictionary input (Int8 keys, Int32 values): subtracting -1 adds one.
    let mut builder = PrimitiveDictionaryBuilder::new(
        PrimitiveBuilder::<Int8Type>::new(3),
        PrimitiveBuilder::<Int32Type>::new(2),
    );
    for item in [Some(5), None, Some(7), Some(8), Some(9)] {
        match item {
            Some(v) => {
                builder.append(v).unwrap();
            }
            None => {
                builder.append_null().unwrap();
            }
        }
    }
    let dict = builder.finish();

    let output = subtract_scalar_dyn::<Int32Type>(&dict, -1_i32).unwrap();
    let output = output
        .as_any()
        .downcast_ref::<DictionaryArray<Int8Type>>()
        .unwrap();
    let dict_values = output
        .values()
        .as_any()
        .downcast_ref::<PrimitiveArray<Int32Type>>()
        .unwrap();
    let expected_dict = [Some(6), None, Some(8), Some(9), Some(10)];
    for (idx, want) in expected_dict.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, dict_values.value(output.key(idx).unwrap())),
            None => assert!(output.is_null(idx)),
        }
    }
}
+
#[test]
fn test_primitive_array_multiply_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8),
Some(9)]);
@@ -1098,6 +1185,46 @@ mod tests {
assert_eq!(90, c.value(4));
}
#[test]
fn test_primitive_array_multiply_scalar_dyn() {
    // Plain primitive input: nulls propagate, every other slot is doubled.
    let input = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
    let output = multiply_scalar_dyn::<Int32Type>(&input, 2_i32).unwrap();
    let output = output.as_any().downcast_ref::<Int32Array>().unwrap();
    let expected_plain = [Some(10), Some(12), Some(14), None, Some(18)];
    for (idx, want) in expected_plain.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, output.value(idx)),
            None => assert!(output.is_null(idx)),
        }
    }

    // Dictionary input (Int8 keys, Int32 values): multiplying by -1 negates.
    let mut builder = PrimitiveDictionaryBuilder::new(
        PrimitiveBuilder::<Int8Type>::new(3),
        PrimitiveBuilder::<Int32Type>::new(2),
    );
    for item in [Some(5), None, Some(7), Some(8), Some(9)] {
        match item {
            Some(v) => {
                builder.append(v).unwrap();
            }
            None => {
                builder.append_null().unwrap();
            }
        }
    }
    let dict = builder.finish();

    let output = multiply_scalar_dyn::<Int32Type>(&dict, -1_i32).unwrap();
    let output = output
        .as_any()
        .downcast_ref::<DictionaryArray<Int8Type>>()
        .unwrap();
    let dict_values = output
        .values()
        .as_any()
        .downcast_ref::<PrimitiveArray<Int32Type>>()
        .unwrap();
    let expected_dict = [Some(-5), None, Some(-7), Some(-8), Some(-9)];
    for (idx, want) in expected_dict.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, dict_values.value(output.key(idx).unwrap())),
            None => assert!(output.is_null(idx)),
        }
    }
}
+
#[test]
fn test_primitive_array_add_sliced() {
let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]);
@@ -1244,6 +1371,50 @@ mod tests {
assert_eq!(c, expected);
}
#[test]
fn test_primitive_array_divide_scalar_dyn() {
    // Plain primitive input: integer division by 2 truncates toward zero.
    let input = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
    let output = divide_scalar_dyn::<Int32Type>(&input, 2_i32).unwrap();
    let output = output.as_any().downcast_ref::<Int32Array>().unwrap();
    let expected_plain = [Some(2), Some(3), Some(3), None, Some(4)];
    for (idx, want) in expected_plain.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, output.value(idx)),
            None => assert!(output.is_null(idx)),
        }
    }

    // Dictionary input (Int8 keys, Int32 values) divided by a negative divisor.
    let mut builder = PrimitiveDictionaryBuilder::new(
        PrimitiveBuilder::<Int8Type>::new(3),
        PrimitiveBuilder::<Int32Type>::new(2),
    );
    for item in [Some(5), None, Some(7), Some(8), Some(9)] {
        match item {
            Some(v) => {
                builder.append(v).unwrap();
            }
            None => {
                builder.append_null().unwrap();
            }
        }
    }
    let dict = builder.finish();

    let output = divide_scalar_dyn::<Int32Type>(&dict, -2_i32).unwrap();
    let output = output
        .as_any()
        .downcast_ref::<DictionaryArray<Int8Type>>()
        .unwrap();
    let dict_values = output
        .values()
        .as_any()
        .downcast_ref::<PrimitiveArray<Int32Type>>()
        .unwrap();
    let expected_dict = [Some(-2), None, Some(-3), Some(-4), Some(-4)];
    for (idx, want) in expected_dict.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(*v, dict_values.value(output.key(idx).unwrap())),
            None => assert!(output.is_null(idx)),
        }
    }

    // A zero divisor must be rejected with DivideByZero.
    let err = divide_scalar_dyn::<Int32Type>(&dict, 0_i32)
        .expect_err("should have failed due to divide by zero");
    assert_eq!("DivideByZero", format!("{:?}", err));
}
+
#[test]
fn test_primitive_array_divide_scalar_sliced() {
let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);