alamb commented on a change in pull request #8660: URL: https://github.com/apache/arrow/pull/8660#discussion_r523416360
########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -969,6 +975,42 @@ macro_rules! compute_utf8_op { }}; } +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_utf8_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}( + &ll, + &string_value, + )?)) + } else { + Err(ExecutionError::General(format!( + "compute_utf8_op_scalar failed to cast literal value {}", Review comment: ```suggestion "internal error: compute_utf8_op_scalar failed to cast literal value {}", ``` The point being that if this code is hit it isn't likely a bug in how someone is using datafusion, it is a bug in datafusion itself. ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -1571,24 +1754,6 @@ impl Literal { } } -/// Build array containing the same literal value repeated. 
This is necessary because the Arrow Review comment: ❤️ ########## File path: rust/datafusion/src/physical_plan/mod.rs ########## @@ -100,6 +100,30 @@ pub enum Distribution { SinglePartition, } +/// Represents the result from an expression +pub enum ColumnarValue { Review comment: I think starting with `ColumnarValue` in DataFusion and then hoisting it out into `arrow` makes a lot of sense ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -1664,8 +1791,18 @@ pub struct PhysicalSortExpr { impl PhysicalSortExpr { /// evaluate the sort expression into SortColumn that can be passed into arrow sort kernel pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result<SortColumn> { + let values_to_sort = self.expr.evaluate(batch)?; + let array_to_sort = match values_to_sort { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => { + return Err(ExecutionError::General(format!( Review comment: again, I like this approach -- we should be removing scalar values out of Sort exprs in the planner, not during execution ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -1288,18 +1363,84 @@ impl PhysicalExpr for BinaryExpr { Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?) 
} - fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { - let left = self.left.evaluate(batch)?; - let right = self.right.evaluate(batch)?; - if left.data_type() != right.data_type() { + fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { + let left_value = self.left.evaluate(batch)?; + let right_value = self.right.evaluate(batch)?; + let left_data_type = left_value.data_type(); + let right_data_type = right_value.data_type(); + + if left_data_type != right_data_type { return Err(ExecutionError::General(format!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, - left.data_type(), - right.data_type() + self.op, left_data_type, right_data_type ))); } - match &self.op { + + let scalar_result = match (&left_value, &right_value) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { + // if left is array and right is literal - use scalar operations + let result: Result<ArrayRef> = match &self.op { + Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt), + Operator::LtEq => { + binary_array_op_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt), + Operator::GtEq => { + binary_array_op_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::NotEq => { + binary_array_op_scalar!(array, scalar.clone(), neq) + } + _ => Err(ExecutionError::General(format!( + "Scalar values on right side of operator {} are not supported", + self.op + ))), + }; + Some(result) + } + (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => { + // if right is literal and left is array - reverse operator and parameters + let result: Result<ArrayRef> = match &self.op { + Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt), + Operator::LtEq => { + binary_array_op_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt), + 
Operator::GtEq => { + binary_array_op_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::NotEq => { + binary_array_op_scalar!(array, scalar.clone(), neq) + } + _ => Err(ExecutionError::General(format!( + "Scalar values on left side of operator {} are not supported", + self.op + ))), + }; + Some(result) + } + (_, _) => None, + }; + + if let Some(result) = scalar_result { + return result.map(|a| ColumnarValue::Array(a)); + } + + let (left, right) = match (left_value, right_value) { + // if both arrays - extract and continue execution + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => (left, right), + // if both literals - not supported Review comment: I think this is fine -- we should handle such `scalar` op `scalar` things in the planner / optimizer, in my opinion ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -1288,18 +1363,84 @@ impl PhysicalExpr for BinaryExpr { Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?) 
} - fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { - let left = self.left.evaluate(batch)?; - let right = self.right.evaluate(batch)?; - if left.data_type() != right.data_type() { + fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { + let left_value = self.left.evaluate(batch)?; + let right_value = self.right.evaluate(batch)?; + let left_data_type = left_value.data_type(); + let right_data_type = right_value.data_type(); + + if left_data_type != right_data_type { return Err(ExecutionError::General(format!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, - left.data_type(), - right.data_type() + self.op, left_data_type, right_data_type ))); } - match &self.op { + + let scalar_result = match (&left_value, &right_value) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { + // if left is array and right is literal - use scalar operations + let result: Result<ArrayRef> = match &self.op { + Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt), + Operator::LtEq => { + binary_array_op_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt), + Operator::GtEq => { + binary_array_op_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::NotEq => { + binary_array_op_scalar!(array, scalar.clone(), neq) + } + _ => Err(ExecutionError::General(format!( + "Scalar values on right side of operator {} are not supported", + self.op + ))), + }; + Some(result) + } + (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => { + // if right is literal and left is array - reverse operator and parameters + let result: Result<ArrayRef> = match &self.op { Review comment: Another structure would be to normalize the invocations by finding the array, and the literal and then have a single call site for invoking the comparison Like turning both `array` > `lit_1` and `lit_1` < 
`array` into A = `array`, lit = `lit_1`, op = `>`. However this involves changing the comparison ops and I am not sure I can claim the code would be any simpler / potentially less bug prone. ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -969,6 +975,42 @@ macro_rules! compute_utf8_op { }}; } +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_utf8_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}( + &ll, + &string_value, + )?)) + } else { + Err(ExecutionError::General(format!( + "compute_utf8_op_scalar failed to cast literal value {}", + $RIGHT + ))) + } + }}; +} + +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + use std::convert::TryInto; + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( Review comment: ```suggestion // generate the scalar function name, such as lt_scalar, from the $OP parameter // (which could have a value of lt) and the suffix _scalar Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( ``` ########## File path: rust/datafusion/src/scalar.rs ########## @@ -115,22 +121,32 @@ impl ScalarValue { /// Converts a scalar value into an 1-row array. pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts a scalar value into an 1-row array. Review comment: ```suggestion /// Converts a scalar value into an array of `size` rows. ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org