alamb commented on code in PR #6527:
URL: https://github.com/apache/arrow-datafusion/pull/6527#discussion_r1218479117


##########
datafusion/expr/src/signature.rs:
##########
@@ -42,10 +42,10 @@ pub enum TypeSignature {
     /// arbitrary number of arguments of an common type out of a list of valid 
types
     // A function such as `concat` is `Variadic(vec![DataType::Utf8, 
DataType::LargeUtf8])`
     Variadic(Vec<DataType>),
-    /// arbitrary number of arguments of an arbitrary but equal type
+    /// arbitrary number of arguments of an equal type
     // A function such as `array` is `VariadicEqual`
     // The first argument decides the type used for coercion
-    VariadicEqual,
+    VariadicEqual(Vec<DataType>),

Review Comment:
   If the types are all equal, what is the purpose of storing a Vec of them? Or 
maybe the Vec is the list of valid types?
   
   It also looks like the code special cases when there are no types specified 
to mean "any type is allowed" -- can we please explicitly mention that in the 
docs as well? Or maybe add a new variant (`VariadicEqualSpecific`?)
   
   



##########
datafusion/physical-expr/src/comparison_expressions.rs:
##########
@@ -0,0 +1,268 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Some of these functions reference the Postgres documentation
+// or implementation to ensure compatibility and are subject to
+// the Postgres license.
+
+//! Comparison expressions
+
+use arrow::array::Array;
+use arrow::datatypes::DataType;
+use arrow_ord::comparison::{gt_dyn, lt_dyn};
+use arrow_select::zip::zip;
+use datafusion_common::scalar::ScalarValue;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::ColumnarValue;
+
+#[derive(Debug, Clone, PartialEq)]
+enum ComparisonOperator {
+    Greatest,
+    Least,
+}
+
+macro_rules! compare_scalar_typed {
+    ($op:expr, $args:expr, $data_type:ident) => {{
+        let value = $args
+            .iter()
+            .filter_map(|scalar| match scalar {
+                ScalarValue::$data_type(v) => v.clone(),
+                _ => panic!("Impossibly got non-scalar values"),
+            })
+            .reduce(|a, b| match $op {
+                ComparisonOperator::Greatest => a.max(b),
+                ComparisonOperator::Least => a.min(b),
+            });
+        ScalarValue::$data_type(value)
+    }};
+}
+
+/// Evaluate a greatest or least function for the case when all arguments are 
scalars
+fn compare_scalars(
+    data_type: DataType,
+    op: ComparisonOperator,
+    args: &[ScalarValue],
+) -> ScalarValue {
+    match data_type {
+        DataType::Boolean => compare_scalar_typed!(op, args, Boolean),
+        DataType::Int8 => compare_scalar_typed!(op, args, Int8),
+        DataType::Int16 => compare_scalar_typed!(op, args, Int16),
+        DataType::Int32 => compare_scalar_typed!(op, args, Int32),
+        DataType::Int64 => compare_scalar_typed!(op, args, Int64),
+        DataType::UInt8 => compare_scalar_typed!(op, args, UInt8),
+        DataType::UInt16 => compare_scalar_typed!(op, args, UInt16),
+        DataType::UInt32 => compare_scalar_typed!(op, args, UInt32),
+        DataType::UInt64 => compare_scalar_typed!(op, args, UInt64),
+        DataType::Float32 => compare_scalar_typed!(op, args, Float32),
+        DataType::Float64 => compare_scalar_typed!(op, args, Float64),
+        DataType::Utf8 => compare_scalar_typed!(op, args, Utf8),
+        DataType::LargeUtf8 => compare_scalar_typed!(op, args, LargeUtf8),
+        _ => panic!("Unsupported data type for comparison: {:?}", data_type),
+    }
+}
+
+/// Evaluate a greatest or least function
+fn compare(op: ComparisonOperator, args: &[ColumnarValue]) -> 
Result<ColumnarValue> {

Review Comment:
   I think this functionality is basically the same as the Min/Max accumulators 
(e.g. 
https://docs.rs/datafusion/latest/datafusion/physical_plan/expressions/struct.MinAccumulator.html)
   
   



##########
datafusion/physical-expr/src/comparison_expressions.rs:
##########
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Some of these functions reference the Postgres documentation
+// or implementation to ensure compatibility and are subject to
+// the Postgres license.
+
+//! Comparison expressions
+
+use arrow::datatypes::DataType;
+use datafusion_common::scalar::ScalarValue;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::ColumnarValue;
+
+#[derive(Debug, Clone, PartialEq)]
+enum ComparisonOperator {
+    Greatest,
+    Least,
+}
+
+macro_rules! compare_scalar_typed {
+    ($op:expr, $args:expr, $data_type:ident) => {{
+        let value = $args
+            .iter()
+            .filter_map(|scalar| match scalar {
+                ScalarValue::$data_type(v) => v.clone(),
+                _ => panic!("Impossibly got non-scalar values"),
+            })
+            .reduce(|a, b| match $op {
+                ComparisonOperator::Greatest => a.max(b),
+                ComparisonOperator::Least => a.min(b),
+            });
+        ScalarValue::$data_type(value)
+    }};
+}
+
+/// Evaluate a greatest or least function for the case when all arguments are 
scalars
+fn compare_scalars(
+    data_type: DataType,
+    op: ComparisonOperator,
+    args: &[ScalarValue],
+) -> ScalarValue {
+    match data_type {
+        DataType::Boolean => compare_scalar_typed!(op, args, Boolean),
+        DataType::Int8 => compare_scalar_typed!(op, args, Int8),
+        DataType::Int16 => compare_scalar_typed!(op, args, Int16),
+        DataType::Int32 => compare_scalar_typed!(op, args, Int32),
+        DataType::Int64 => compare_scalar_typed!(op, args, Int64),
+        DataType::UInt8 => compare_scalar_typed!(op, args, UInt8),
+        DataType::UInt16 => compare_scalar_typed!(op, args, UInt16),
+        DataType::UInt32 => compare_scalar_typed!(op, args, UInt32),
+        DataType::UInt64 => compare_scalar_typed!(op, args, UInt64),
+        DataType::Float32 => compare_scalar_typed!(op, args, Float32),
+        DataType::Float64 => compare_scalar_typed!(op, args, Float64),
+        DataType::Utf8 => compare_scalar_typed!(op, args, Utf8),
+        DataType::LargeUtf8 => compare_scalar_typed!(op, args, LargeUtf8),
+        _ => panic!("Unsupported data type for comparison: {:?}", data_type),
+    }
+}
+
+/// Evaluate a greatest or least function
+fn compare(op: ComparisonOperator, args: &[ColumnarValue]) -> 
Result<ColumnarValue> {
+    if args.is_empty() {
+        return Err(DataFusionError::Internal(format!(
+            "{:?} expressions require at least one argument",
+            op
+        )));
+    } else if args.len() == 1 {
+        return Ok(args[0].clone());
+    }
+
+    let args_types = args
+        .iter()
+        .map(|arg| match arg {
+            ColumnarValue::Array(array) => array.data_type().clone(),
+            ColumnarValue::Scalar(scalar) => scalar.get_datatype(),
+        })
+        .collect::<Vec<_>>();
+
+    if args_types.iter().any(|t| t != &args_types[0]) {
+        return Err(DataFusionError::Internal(format!(
+            "{:?} expressions require all arguments to be of the same type",
+            op
+        )));
+    }
+
+    let all_scalars = args.iter().all(|arg| match arg {
+        ColumnarValue::Array(_) => false,
+        ColumnarValue::Scalar(_) => true,
+    });
+
+    if all_scalars {
+        let args: Vec<_> = args
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Array(_) => {
+                    panic!("Internal error: all arguments should be scalars")
+                }
+                ColumnarValue::Scalar(scalar) => scalar.clone(),
+            })
+            .collect();
+        Ok(ColumnarValue::Scalar(compare_scalars(
+            args_types[0].clone(),
+            op,
+            &args,
+        )))
+    } else {
+        Err(DataFusionError::NotImplemented(format!(
+            "{:?} expressions are not implemented for arrays yet as we need to 
update arrow kernels",

Review Comment:
   There is a `min` kernel (https://docs.rs/arrow/latest/arrow/compute/fn.min.html), 
but I think it only works for primitive arrays
   
   Perhaps you can use the existing MinAccumulator / MaxAccumulator? 
   
   
https://docs.rs/datafusion/latest/datafusion/physical_plan/expressions/struct.MinAccumulator.html
   
   



##########
datafusion/core/tests/sql/expr.rs:
##########
@@ -200,6 +200,41 @@ async fn binary_bitwise_shift() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_comparison_func_expressions() -> Result<()> {
+    test_expression!("greatest(1,2,3)", "3");
+    test_expression!("least(1,2,3)", "1");
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_comparison_func_array_scalar_expression() -> Result<()> {

Review Comment:
   Can we please write this test using `sqllogictest` instead of a new rs test? 
I think you'll find it quite nice for sql tests and it is easier to maintain 
and extend. 
   
   
https://github.com/apache/arrow-datafusion/tree/main/datafusion/core/tests/sqllogictests#cookbook-adding-tests
   



##########
datafusion/expr/src/type_coercion/functions.rs:
##########
@@ -236,6 +256,58 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_get_valid_types_variadic_equal() -> Result<()> {
+        let signature = TypeSignature::VariadicEqual(vec![DataType::Int32]);
+
+        let valid_types = get_valid_types(
+            &signature,
+            &[DataType::Int32, DataType::Int32, DataType::Int32],
+        )?;
+        assert_eq!(valid_types.len(), 1);
+        assert_eq!(
+            valid_types[0],
+            vec![DataType::Int32, DataType::Int32, DataType::Int32]
+        );
+
+        // invalid case, with int and boolean
+        let is_error = get_valid_types(
+            &signature,
+            &[DataType::Int32, DataType::Boolean, DataType::Int32],
+        )
+        .is_err();
+        assert!(is_error);

Review Comment:
   Another way to test for error is doing:
   
   ```suggestion
           let is_error = get_valid_types(
               &signature,
               &[DataType::Int32, DataType::Boolean, DataType::Int32],
           )
           .unwrap_err();
   ```
   
   This is more concise, and I think it also prints out what the `Ok` value 
is when the result is not an error



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to