alamb commented on code in PR #3705:
URL: https://github.com/apache/arrow-datafusion/pull/3705#discussion_r988356383
##########
datafusion/physical-expr/src/aggregate/min_max.rs:
##########
@@ -296,41 +290,18 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
_ => min_max_batch!(values, max),
})
}
-macro_rules! typed_min_max_decimal {
- ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident,
$OP:ident) => {{
- ScalarValue::$SCALAR(
- match ($VALUE, $DELTA) {
- (None, None) => None,
- (Some(a), None) => Some(a.clone()),
- (None, Some(b)) => Some(b.clone()),
- (Some(a), Some(b)) => Some((*a).$OP(*b)),
- },
- $PRECISION.clone(),
- $SCALE.clone(),
- )
- }};
-}
// min/max of two non-string scalar values.
macro_rules! typed_min_max {
- ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
- ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
- (None, None) => None,
- (Some(a), None) => Some(a.clone()),
- (None, Some(b)) => Some(b.clone()),
- (Some(a), Some(b)) => Some((*a).$OP(*b)),
- })
- }};
-
- ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+ ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(,
$EXTRA_ARGS:ident)*) => {{
Review Comment:
These are drive-by cleanups to reduce duplication in the macros, right?
(BTW, if you submit these as individual free-standing PRs you might find the
reviews are faster -- finding enough contiguous time to review large PR changes
can be challenging at times.)
##########
datafusion/common/src/scalar.rs:
##########
@@ -312,155 +312,186 @@ impl Eq for ScalarValue {}
// TODO implement this in arrow-rs with simd
// https://github.com/apache/arrow-rs/issues/1010
macro_rules! decimal_op {
- ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr,
$OPERATION:tt ) => {{
- let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
- ($LHS_SCALE - $RHS_SCALE, true)
- } else {
- ($RHS_SCALE - $LHS_SCALE, false)
- };
- let scale = max($LHS_SCALE, $RHS_SCALE);
- match ($LHS, $RHS, difference) {
- (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale),
- (None, Some(rhs_value), 0) => ScalarValue::Decimal128(Some((0 as i128)
$OPERATION rhs_value), $PRECISION, scale),
- (None, Some(rhs_value), _) => {
- let mut new_value = ((0 as i128) $OPERATION rhs_value);
- if side {
- new_value *= 10_i128.pow((difference) as u32)
- };
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- (Some(lhs_value), None, 0) => ScalarValue::Decimal128(Some(lhs_value
$OPERATION (0 as i128)), $PRECISION, scale),
- (Some(lhs_value), None, _) => {
- let mut new_value = (lhs_value $OPERATION (0 as i128));
- if !!!side {
- new_value *= 10_i128.pow((difference) as u32)
+ ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr,
$OPERATION:tt) => {{
+ let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
+ ($LHS_SCALE - $RHS_SCALE, true)
+ } else {
+ ($RHS_SCALE - $LHS_SCALE, false)
+ };
+ let scale = max($LHS_SCALE, $RHS_SCALE);
+ Ok(match ($LHS, $RHS, difference) {
+ (None, None, _) => ScalarValue::Decimal128(None, $PRECISION,
scale),
+ (lhs, None, 0) => ScalarValue::Decimal128(*lhs, $PRECISION, scale),
+ (Some(lhs_value), None, _) => {
+ let mut new_value = *lhs_value;
+ if !side {
+ new_value *= 10_i128.pow(difference as u32)
+ }
+ ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
}
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- (Some(lhs_value), Some(rhs_value), 0) => {
- ScalarValue::Decimal128(Some(lhs_value $OPERATION rhs_value),
$PRECISION, scale)
- }
- (Some(lhs_value), Some(rhs_value), _) => {
- let new_value = if side {
- rhs_value * 10_i128.pow((difference) as u32) $OPERATION
lhs_value
- } else {
- lhs_value * 10_i128.pow((difference) as u32) $OPERATION
rhs_value
- };
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- }}
+ (None, Some(rhs_value), 0) => {
+ let value = decimal_right!(*rhs_value, $OPERATION);
+ ScalarValue::Decimal128(Some(value), $PRECISION, scale)
+ }
+ (None, Some(rhs_value), _) => {
+ let mut new_value = decimal_right!(*rhs_value, $OPERATION);
+ if side {
+ new_value *= 10_i128.pow(difference as u32)
+ };
+ ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
+ }
+ (Some(lhs_value), Some(rhs_value), 0) => {
+ decimal_binary_op!(lhs_value, rhs_value, $OPERATION,
$PRECISION, scale)
+ }
+ (Some(lhs_value), Some(rhs_value), _) => {
+ let (left_arg, right_arg) = if side {
+ (*lhs_value, rhs_value * 10_i128.pow(difference as u32))
+ } else {
+ (lhs_value * 10_i128.pow(difference as u32), *rhs_value)
+ };
+ decimal_binary_op!(left_arg, right_arg, $OPERATION,
$PRECISION, scale)
+ }
+ })
+ }};
+}
- }
+macro_rules! decimal_binary_op {
+ ($LHS:expr, $RHS:expr, $OPERATION:tt, $PRECISION:expr, $SCALE:expr) => {
+ // TODO: This simple implementation loses precision for calculations
like
+ // multiplication and division. Improve this implementation for
such
+ // operations.
+ ScalarValue::Decimal128(Some($LHS $OPERATION $RHS), $PRECISION, $SCALE)
+ };
}
-// Returns the result of applying operation to two scalar values, including
coercion into $TYPE.
-macro_rules! typed_op {
- ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $TYPE:ident, $OPERATION:tt) => {
- Some(ScalarValue::$SCALAR(match ($LEFT, $RIGHT) {
- (None, None) => None,
- (Some(a), None) => Some((*a as $TYPE) $OPERATION (0 as $TYPE)),
- (None, Some(b)) => Some((0 as $TYPE) $OPERATION (*b as $TYPE)),
- (Some(a), Some(b)) => Some((*a as $TYPE) $OPERATION (*b as $TYPE)),
- }))
+macro_rules! decimal_right {
+ ($TERM:expr, +) => {
+ $TERM
+ };
+ ($TERM:expr, *) => {
+ $TERM
+ };
+ ($TERM:expr, -) => {
+ -$TERM
+ };
+ ($TERM:expr, /) => {
+ Err(DataFusionError::NotImplemented(format!(
+ "Decimal reciprocation not yet supported",
+ )))
};
}
-macro_rules! impl_common_symmetric_cases_op {
- ($LHS:expr, $RHS:expr, $OPERATION:tt, [$([$L_TYPE:ident, $R_TYPE:ident,
$O_TYPE:ident, $O_PRIM:ident]),+]) => {
- match ($LHS, $RHS) {
- $(
- (ScalarValue::$L_TYPE(lhs), ScalarValue::$R_TYPE(rhs)) => {
- typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
- }
- (ScalarValue::$R_TYPE(lhs), ScalarValue::$L_TYPE(rhs)) => {
- typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
- }
- )+
- _ => None
+// Returns the result of applying operation to two scalar values.
+macro_rules! primitive_op {
+ ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $OPERATION:tt) => {
+ match ($LEFT, $RIGHT) {
+ (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)),
+ #[allow(unused_variables)]
+ (None, Some(b)) => { primitive_right!(*b, $OPERATION, $SCALAR) },
+ (Some(a), Some(b)) => Ok(ScalarValue::$SCALAR(Some(*a $OPERATION
*b))),
}
- }
+ };
}
-macro_rules! impl_common_cases_op {
+macro_rules! primitive_right {
+ ($TERM:expr, +, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM)))
+ };
+ ($TERM:expr, *, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM)))
+ };
+ ($TERM:expr, -, UInt64) => {
+ unsigned_subtraction_error!("UInt64")
+ };
+ ($TERM:expr, -, UInt32) => {
+ unsigned_subtraction_error!("UInt32")
+ };
+ ($TERM:expr, -, UInt16) => {
+ unsigned_subtraction_error!("UInt16")
+ };
+ ($TERM:expr, -, UInt8) => {
+ unsigned_subtraction_error!("UInt8")
+ };
+ ($TERM:expr, -, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some(-$TERM)))
+ };
+ ($TERM:expr, /, Float64) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+ };
+ ($TERM:expr, /, Float32) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+ };
+ ($TERM:expr, /, $SCALAR:ident) => {
+ Err(DataFusionError::Internal(format!(
+ "Can not divide an uninitialized value to a non-floating point
value",
+ )))
+ };
+}
+
+macro_rules! unsigned_subtraction_error {
+ ($SCALAR:expr) => {{
+ let msg = format!(
+ "Can not subtract a {} value from an uninitialized value",
+ $SCALAR
+ );
+ Err(DataFusionError::Internal(msg))
+ }};
+}
+
+macro_rules! impl_op {
($LHS:expr, $RHS:expr, $OPERATION:tt) => {
match ($LHS, $RHS) {
(
ScalarValue::Decimal128(v1, p1, s1),
ScalarValue::Decimal128(v2, p2, s2),
) => {
- let max_precision = *p1.max(p2);
- Some(decimal_op!(v1, v2, max_precision, *s1, *s2, $OPERATION))
+ decimal_op!(v1, v2, *p1.max(p2), *s1, *s2, $OPERATION)
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
- typed_op!(lhs, rhs, Float64, f64, $OPERATION)
+ primitive_op!(lhs, rhs, Float64, $OPERATION)
}
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
- typed_op!(lhs, rhs, Float32, f32, $OPERATION)
+ primitive_op!(lhs, rhs, Float32, $OPERATION)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
- typed_op!(lhs, rhs, UInt64, u64, $OPERATION)
+ primitive_op!(lhs, rhs, UInt64, $OPERATION)
}
(ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
- typed_op!(lhs, rhs, Int64, i64, $OPERATION)
+ primitive_op!(lhs, rhs, Int64, $OPERATION)
}
(ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
- typed_op!(lhs, rhs, UInt32, u32, $OPERATION)
+ primitive_op!(lhs, rhs, UInt32, $OPERATION)
}
(ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
- typed_op!(lhs, rhs, Int32, i32, $OPERATION)
+ primitive_op!(lhs, rhs, Int32, $OPERATION)
}
(ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
- typed_op!(lhs, rhs, UInt16, u16, $OPERATION)
+ primitive_op!(lhs, rhs, UInt16, $OPERATION)
}
(ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
- typed_op!(lhs, rhs, Int16, i16, $OPERATION)
+ primitive_op!(lhs, rhs, Int16, $OPERATION)
}
(ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
- typed_op!(lhs, rhs, UInt8, u8, $OPERATION)
+ primitive_op!(lhs, rhs, UInt8, $OPERATION)
}
(ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
- typed_op!(lhs, rhs, Int8, i8, $OPERATION)
+ primitive_op!(lhs, rhs, Int8, $OPERATION)
+ }
+ _ => {
+ impl_distinct_cases_op!($LHS, $RHS, $OPERATION)
}
- _ => impl_common_symmetric_cases_op!(
Review Comment:
👍
##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -206,87 +205,37 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type:
&DataType) -> Result<Scalar
macro_rules! sum_row {
($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{
paste::item! {
- match $DELTA {
- None => {}
- Some(v) => $ACC.[<add_ $TYPE>]($INDEX, *v as $TYPE)
+ if let Some(v) = $DELTA {
+ $ACC.[<add_ $TYPE>]($INDEX, *v)
}
}
}};
}
pub(crate) fn add_to_row(
- dt: &DataType,
index: usize,
accessor: &mut RowAccessor,
s: &ScalarValue,
) -> Result<()> {
- match (dt, s) {
- // float64 coerces everything to f64
- (DataType::Float64, ScalarValue::Float64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Float32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int16(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int8(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt16(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt8(rhs)) => {
+ match s {
Review Comment:
I think this is a good change -- to use the type of the input to determine the
type of the accumulator.
##########
datafusion/physical-expr/src/aggregate/min_max.rs:
##########
@@ -154,16 +154,10 @@ macro_rules! typed_min_max_batch_string {
// Statically-typed version of min/max(array) -> ScalarValue for non-string
types.
macro_rules! typed_min_max_batch {
- ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
- let array = downcast_value!($VALUES, $ARRAYTYPE);
- let value = compute::$OP(array);
- ScalarValue::$SCALAR(value)
- }};
-
- ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+ ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(,
$EXTRA_ARGS:ident)*) => {{
Review Comment:
Macro 🧙
##########
datafusion/common/src/scalar.rs:
##########
@@ -938,11 +960,13 @@ impl ScalarValue {
}
pub fn is_unsigned(&self) -> bool {
- let value_type = self.get_datatype();
- value_type == DataType::UInt64
- || value_type == DataType::UInt32
- || value_type == DataType::UInt16
- || value_type == DataType::UInt8
+ matches!(
Review Comment:
👍
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]