gstvg commented on code in PR #18921:
URL: https://github.com/apache/datafusion/pull/18921#discussion_r2976118931


##########
datafusion/expr/src/udlf.rs:
##########
@@ -0,0 +1,550 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`LambdaUDF`]: Lambda User Defined Functions
+
+use crate::expr::schema_name_from_exprs_comma_separated_without_space;
+use crate::{ColumnarValue, Documentation, Expr};
+use arrow::array::{ArrayRef, RecordBatch};
+use arrow::datatypes::{DataType, Field, FieldRef, Schema};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{Result, ScalarValue, not_impl_err};
+use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::signature::Volatility;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::cmp::Ordering;
+use std::fmt::Debug;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// The types of arguments for which a function has implementations.
+///
+/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query 
could call the
+/// function with. DataFusion will automatically coerce (cast) argument types 
to
+/// one of the supported function signatures, if possible.
+///
+/// # Overview
+/// Functions typically provide implementations for a small number of different
+/// argument [`DataType`]s, rather than all possible combinations. If a user
+/// calls a function with arguments that do not match any of the declared 
types,
+/// DataFusion will attempt to automatically coerce (add casts to) function
+/// arguments so they match the [`LambdaTypeSignature`]. See the 
[`type_coercion`] module
+/// for more details
+///
+/// [`type_coercion`]: crate::type_coercion
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum LambdaTypeSignature {
+    /// The acceptable signature and coercion rules are special for this
+    /// function.
+    ///
+    /// If this signature is specified,
+    /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare 
argument types.
+    UserDefined,
+    /// One or more lambdas or arguments with arbitrary types
+    VariadicAny,
+    /// The specified number of lambdas or arguments with arbitrary types.
+    Any(usize),
+}
+
+/// Provides information necessary for calling a lambda function.
+///
+/// - [`LambdaTypeSignature`] defines the argument types that a function has 
implementations
+///   for.
+///
+/// - [`Volatility`] defines how the output of the function changes with the 
input.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub struct LambdaSignature {
+    /// The data types that the function accepts. See [LambdaTypeSignature] 
for more information.
+    pub type_signature: LambdaTypeSignature,
+    /// The volatility of the function. See [Volatility] for more information.
+    pub volatility: Volatility,
+}
+
+impl LambdaSignature {
+    /// Creates a new `LambdaSignature` from a given type signature and 
volatility.
+    pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> 
Self {
+        LambdaSignature {
+            type_signature,
+            volatility,
+        }
+    }
+
+    /// User-defined coercion rules for the function.
+    pub fn user_defined(volatility: Volatility) -> Self {
+        Self {
+            type_signature: LambdaTypeSignature::UserDefined,
+            volatility,
+        }
+    }
+
+    /// An arbitrary number of lambdas or arguments of any type.
+    pub fn variadic_any(volatility: Volatility) -> Self {
+        Self {
+            type_signature: LambdaTypeSignature::VariadicAny,
+            volatility,
+        }
+    }
+
+    /// A specified number of lambdas or arguments of any type.
+    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
+        Self {
+            type_signature: LambdaTypeSignature::Any(arg_count),
+            volatility,
+        }
+    }
+}
+
+impl PartialEq for dyn LambdaUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.dyn_eq(other.as_any())
+    }
+}
+
+impl PartialOrd for dyn LambdaUDF {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        let mut cmp = self.name().cmp(other.name());
+        if cmp == Ordering::Equal {
+            cmp = self.signature().partial_cmp(other.signature())?;
+        }
+        if cmp == Ordering::Equal {
+            cmp = self.aliases().partial_cmp(other.aliases())?;
+        }
+        // Contract for PartialOrd and PartialEq consistency requires that
+        // a == b if and only if partial_cmp(a, b) == Some(Equal).
+        if cmp == Ordering::Equal && self != other {
+            // Functions may have other properties besides name and signature
+            // that differentiate two instances (e.g. type, or arbitrary 
parameters).
+            // We cannot return Some(Equal) in such case.
+            return None;
+        }
+        debug_assert!(
+            cmp == Ordering::Equal || self != other,
+            "Detected incorrect implementation of PartialEq when comparing 
functions: '{}' and '{}'. \
+            The functions compare as equal, but they are not equal based on 
general properties that \
+            the PartialOrd implementation observes,",
+            self.name(),
+            other.name()
+        );
+        Some(cmp)
+    }
+}
+
+impl Eq for dyn LambdaUDF {}
+
+impl Hash for dyn LambdaUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.dyn_hash(state)
+    }
+}
+
+/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a
+/// lambda function.
+#[derive(Debug, Clone)]
+pub struct LambdaFunctionArgs {
+    /// The evaluated arguments and lambdas to the function
+    pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
+    /// Field associated with each arg, if it exists
+    /// For lambdas, it will be the field of the result of
+    /// the lambda if evaluated with the parameters
+    /// returned from [`LambdaUDF::lambdas_parameters`]
+    pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
+    /// The number of rows in record batch being evaluated
+    pub number_rows: usize,
+    /// The return field of the lambda function, as returned
+    /// from `return_field_from_args` when creating the
+    /// physical expression from the logical expression
+    pub return_field: FieldRef,
+    /// The config options at execution time
+    pub config_options: Arc<ConfigOptions>,
+}
+
+impl LambdaFunctionArgs {
+    /// The return type of the function. See [`Self::return_field`] for more
+    /// details.
+    pub fn return_type(&self) -> &DataType {
+        self.return_field.data_type()
+    }
+}
+
+/// A lambda argument to a LambdaFunction
+#[derive(Clone, Debug)]
+pub struct LambdaArgument {
+    /// The parameters defined in this lambda
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be `vec![Field::new("v", DataType::Int32, true)]`
+    params: Vec<FieldRef>,
+    /// The body of the lambda
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be the physical expression of `-v`
+    body: Arc<dyn PhysicalExpr>,
+}
+
+impl LambdaArgument {
+    pub fn new(params: Vec<FieldRef>, body: Arc<dyn PhysicalExpr>) -> Self {
+        Self { params, body }
+    }
+
+    /// Evaluate this lambda
+    /// `args` should evaluate to the value of each parameter
+    /// of the corresponding lambda returned in [LambdaUDF::lambdas_parameters].
+    pub fn evaluate(
+        &self,
+        args: &[&dyn Fn() -> Result<ArrayRef>],
+    ) -> Result<ColumnarValue> {
+        let columns = args
+            .iter()
+            .take(self.params.len())
+            .map(|arg| arg())
+            .collect::<Result<_>>()?;
+
+        let schema = Arc::new(Schema::new(self.params.clone()));
+
+        let batch = RecordBatch::try_new(schema, columns)?;
+
+        self.body.evaluate(&batch)
+    }
+}
+
+/// Information about arguments passed to the function
+///
+/// This structure contains metadata about how the function was called
+/// such as the type of the arguments, any scalar arguments and if the
+/// arguments can (ever) be null
+///
+/// See [`LambdaUDF::return_field_from_args`] for more information
+#[derive(Clone, Debug)]
+pub struct LambdaReturnFieldArgs<'a> {
+    /// The data types of the arguments to the function
+    ///
+    /// If argument `i` to the function is a lambda, it will be the field of 
the result of the
+    /// lambda if evaluated with the parameters returned from 
[`LambdaUDF::lambdas_parameters`]
+    ///
+    /// For example, with `array_transform([1], v -> v == 5)`
+    /// this field will be `[
+    ///     ValueOrLambda::Value(Field::new("", 
DataType::List(DataType::Int32), false)),
+    ///     ValueOrLambda::Lambda(Field::new("", DataType::Boolean, false))
+    /// ]`
+    pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
+    /// Is argument `i` to the function a scalar (constant)?
+    ///
+    /// If the argument `i` is not a scalar, it will be None
+    ///
+    /// For example, if a function is called like `array_transform([1], v -> v == 5)`
+    /// this field will be `[Some(ScalarValue::List(...)), None]`
+    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
+}
+
+/// An argument to a lambda function
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ValueOrLambda<V, L> {
+    /// A value with associated data
+    Value(V),
+    /// A lambda with associated data
+    Lambda(L),
+}
+
+/// Trait for implementing user defined lambda functions.
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`array_transform.rs`] for a commented complete implementation
+///
+/// [`array_transform.rs`]: 
https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs
+pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Aliases can be used to invoke the same function using different names.
+    /// For example in some databases `now()` and `current_timestamp()` are
+    /// aliases for the same function. This behavior can be obtained by
+    /// returning `current_timestamp` as an alias for the `now` function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+
+    /// Returns the name of the column this expression would create
+    ///
+    /// See [`Expr::schema_name`] for details
+    fn schema_name(&self, args: &[Expr]) -> Result<String> {
+        Ok(format!(
+            "{}({})",
+            self.name(),
+            schema_name_from_exprs_comma_separated_without_space(args)?
+        ))
+    }
+
+    /// Returns a [`LambdaSignature`] describing the argument types for which 
this
+    /// function has an implementation, and the function's [`Volatility`].
+    ///
+    /// See [`LambdaSignature`] for more details on argument type handling
+    /// and [`Self::return_field_from_args`] for computing the return type.
+    ///
+    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
+    fn signature(&self) -> &LambdaSignature;
+
+    /// Returns a list of the same length as args, where each element is the result of
+    /// applying the rule below to the argument at the corresponding position in args:
+    ///
+    /// If it's a value, return None
+    /// If it's a lambda, return the list of all parameters that the lambda supports
+    ///
+    /// Example for array_transform:
+    ///
+    /// `array_transform([2.0, 8.0], v -> v > 4.0)`
+    ///
+    /// ```ignore
+    /// let lambdas_parameters = array_transform.lambdas_parameters(&[
+    ///      ValueOrLambdaParameter::Value(Field::new("", 
DataType::new_list(DataType::Float32, false)))]), // the Field of the literal 
`[2, 8]`
+    ///      ValueOrLambdaParameter::Lambda, // A lambda

Review Comment:
   I will apply your suggestions regarding `lambdas_parameters` later, and this
   will become a vec of `Field`s. Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to