kszucs commented on a change in pull request #873:
URL: https://github.com/apache/arrow-datafusion/pull/873#discussion_r738589948



##########
File path: python/src/functions.rs
##########
@@ -15,232 +15,210 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::udaf;
-use crate::udf;
-use crate::{expression, types::PyDataType};
-use datafusion::arrow::datatypes::DataType;
-use datafusion::logical_plan::{self, Literal};
-use datafusion::physical_plan::functions::Volatility;
-use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python};
 use std::sync::Arc;
 
-/// Expression representing a column on the existing plan.
-#[pyfunction]
-#[pyo3(text_signature = "(name)")]
-fn col(name: &str) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::col(name),
-    }
-}
-
-/// # A bridge type that converts PyAny data into datafusion literal
-///
-/// Note that the ordering here matters because it has to be from
-/// narrow to wider values because Python has duck typing so putting
-/// Int before Boolean results in a premature match.
-#[derive(FromPyObject)]
-enum PythonLiteral<'a> {
-    Boolean(bool),
-    Int(i64),
-    UInt(u64),
-    Float(f64),
-    Str(&'a str),
-    Binary(&'a [u8]),
-}
+use pyo3::{prelude::*, wrap_pyfunction, Python};
 
-impl<'a> Literal for PythonLiteral<'a> {
-    fn lit(&self) -> logical_plan::Expr {
-        match self {
-            PythonLiteral::Boolean(val) => val.lit(),
-            PythonLiteral::Int(val) => val.lit(),
-            PythonLiteral::UInt(val) => val.lit(),
-            PythonLiteral::Float(val) => val.lit(),
-            PythonLiteral::Str(val) => val.lit(),
-            PythonLiteral::Binary(val) => val.lit(),
-        }
-    }
-}
+use datafusion::arrow::datatypes::DataType;
+use datafusion::logical_plan;
+//use datafusion::logical_plan::Expr;
+use datafusion::physical_plan::functions::Volatility;
+use datafusion::physical_plan::{
+    aggregates::AggregateFunction, functions::BuiltinScalarFunction,
+};
 
-/// Expression representing a constant value
-#[pyfunction]
-#[pyo3(text_signature = "(value)")]
-fn lit(value: &PyAny) -> PyResult<expression::Expression> {
-    let py_lit = value.extract::<PythonLiteral>()?;
-    let expr = py_lit.lit();
-    Ok(expression::Expression { expr })
-}
+use crate::{
+    expression::{PyAggregateUDF, PyExpr, PyScalarUDF},
+    udaf, udf,
+};
 
 #[pyfunction]
-fn array(value: Vec<expression::Expression>) -> expression::Expression {
-    expression::Expression {
+fn array(value: Vec<PyExpr>) -> PyExpr {
+    PyExpr {
         expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::<Vec<_>>()),
     }
 }
 
 #[pyfunction]
-fn in_list(
-    expr: expression::Expression,
-    value: Vec<expression::Expression>,
-    negated: bool,
-) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::in_list(
-            expr.expr,
-            value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
-            negated,
-        ),
-    }
+fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
+    logical_plan::in_list(
+        expr.expr,
+        value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
+        negated,
+    )
+    .into()
 }
 
 /// Current date and time
 #[pyfunction]
-fn now() -> expression::Expression {
-    expression::Expression {
+fn now() -> PyExpr {
+    PyExpr {
         // here lit(0) is a stub for conform to arity
         expr: logical_plan::now(logical_plan::lit(0)),
     }
 }
 
 /// Returns a random value in the range 0.0 <= x < 1.0
 #[pyfunction]
-fn random() -> expression::Expression {
-    expression::Expression {
+fn random() -> PyExpr {
+    PyExpr {
         expr: logical_plan::random(),
     }
 }
 
 /// Computes a binary hash of the given data. type is the algorithm to use.
 /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
 #[pyfunction(value, method)]
-fn digest(
-    value: expression::Expression,
-    method: expression::Expression,
-) -> expression::Expression {
-    expression::Expression {
+fn digest(value: PyExpr, method: PyExpr) -> PyExpr {
+    PyExpr {
         expr: logical_plan::digest(value.expr, method.expr),
     }
 }
 
 /// Concatenates the text representations of all the arguments.
 /// NULL arguments are ignored.
 #[pyfunction(args = "*")]
-fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
-    let expressions = expression::from_tuple(args)?;
-    let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
-        expr: logical_plan::concat(&args),
-    })
+fn concat(args: Vec<PyExpr>) -> PyResult<PyExpr> {
+    let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
+    Ok(logical_plan::concat(&args).into())
 }
 
 /// Concatenates all but the first argument, with separators.
 /// The first argument is used as the separator string, and should not be NULL.
 /// Other NULL arguments are ignored.
 #[pyfunction(sep, args = "*")]
-fn concat_ws(sep: String, args: &PyTuple) -> PyResult<expression::Expression> {
-    let expressions = expression::from_tuple(args)?;
-    let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
-        expr: logical_plan::concat_ws(sep, &args),
-    })
+fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
+    let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
+    Ok(logical_plan::concat_ws(sep, &args).into())
 }
 
-macro_rules! define_unary_function {
-    ($NAME: ident) => {
-        #[doc = "This function is not documented yet"]
-        #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
-                expr: logical_plan::$NAME(value.expr),
-            }
-        }
+macro_rules! scalar_function {
+    ($NAME: ident, $FUNC: ident) => {
+        scalar_function!($NAME, $FUNC, stringify!($NAME));
     };
-    ($NAME: ident, $DOC: expr) => {
+    ($NAME: ident, $FUNC: ident, $DOC: expr) => {
         #[doc = $DOC]
-        #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
-                expr: logical_plan::$NAME(value.expr),
-            }
+        #[pyfunction(args = "*")]
+        fn $NAME(args: Vec<PyExpr>) -> PyExpr {
+            let expr = logical_plan::Expr::ScalarFunction {
+                fun: BuiltinScalarFunction::$FUNC,
+                args: args.into_iter().map(|e| e.into()).collect(),
+            };
+            expr.into()
         }
     };
 }
 
-define_unary_function!(sqrt, "sqrt");
-define_unary_function!(sin, "sin");
-define_unary_function!(cos, "cos");
-define_unary_function!(tan, "tan");
-define_unary_function!(asin, "asin");
-define_unary_function!(acos, "acos");
-define_unary_function!(atan, "atan");
-define_unary_function!(floor, "floor");
-define_unary_function!(ceil, "ceil");
-define_unary_function!(round, "round");
-define_unary_function!(trunc, "trunc");
-define_unary_function!(abs, "abs");
-define_unary_function!(signum, "signum");
-define_unary_function!(exp, "exp");
-define_unary_function!(ln, "ln");
-define_unary_function!(log2, "log2");
-define_unary_function!(log10, "log10");
+macro_rules! aggregate_function {
+    ($NAME: ident, $FUNC: ident) => {
+        aggregate_function!($NAME, $FUNC, stringify!($NAME));
+    };
+    ($NAME: ident, $FUNC: ident, $DOC: expr) => {
+        #[doc = $DOC]
+        #[pyfunction(args = "*", distinct = "false")]
+        fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
+            let expr = logical_plan::Expr::AggregateFunction {
+                fun: AggregateFunction::$FUNC,
+                args: args.into_iter().map(|e| e.into()).collect(),
+                distinct,
+            };
+            expr.into()
+        }
+    };
+}
 
-define_unary_function!(ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
-define_unary_function!(sum);
-define_unary_function!(
+scalar_function!(abs, Abs);
+scalar_function!(acos, Acos);
+scalar_function!(ascii, Ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
+scalar_function!(asin, Asin);
+scalar_function!(atan, Atan);
+scalar_function!(
     bit_length,
+    BitLength,
     "Returns number of bits in the string (8 times the octet_length)."
 );
-define_unary_function!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
-define_unary_function!(
+scalar_function!(btrim, Btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
+scalar_function!(ceil, Ceil);
+scalar_function!(
     character_length,
+    CharacterLength,
     "Returns number of characters in the string."
 );
-define_unary_function!(chr, "Returns the character with the given code.");
-define_unary_function!(initcap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
-define_unary_function!(left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
-define_unary_function!(lower, "Converts the string to all lower case");
-define_unary_function!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
-define_unary_function!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string.");
-define_unary_function!(
+scalar_function!(chr, Chr, "Returns the character with the given code.");
+scalar_function!(cos, Cos);
+scalar_function!(exp, Exp);
+scalar_function!(floor, Floor);
+scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
+scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
+scalar_function!(ln, Ln);
+scalar_function!(log10, Log10);
+scalar_function!(log2, Log2);
+scalar_function!(lower, Lower, "Converts the string to all lower case");
+scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
+scalar_function!(ltrim, Ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string.");
+scalar_function!(
     md5,
+    MD5,
     "Computes the MD5 hash of the argument, with the result written in hexadecimal."
 );
-define_unary_function!(octet_length, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
-define_unary_function!(
-    replace,
-    "Replaces all occurrences in string of substring from with substring to."
-);
-define_unary_function!(repeat, "Repeats string the specified number of times.");
-define_unary_function!(
+scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
+scalar_function!(regexp_match, RegexpMatch);
+scalar_function!(
     regexp_replace,
+    RegexpReplace,
     "Replaces substring(s) matching a POSIX regular expression"
 );
-define_unary_function!(
+scalar_function!(
+    repeat,
+    Repeat,
+    "Repeats string the specified number of times."
+);
+scalar_function!(
+    replace,
+    Replace,
+    "Replaces all occurrences in string of substring from with substring to."
+);
+scalar_function!(
     reverse,
+    Reverse,
     "Reverses the order of the characters in the string."
 );
-define_unary_function!(right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
-define_unary_function!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
-define_unary_function!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string.");
-define_unary_function!(sha224);
-define_unary_function!(sha256);
-define_unary_function!(sha384);
-define_unary_function!(sha512);
-define_unary_function!(split_part, "Splits string at occurrences of delimiter and returns the n'th field (counting from one).");
-define_unary_function!(starts_with, "Returns true if string starts with prefix.");
-define_unary_function!(strpos,"Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
-define_unary_function!(substr);
-define_unary_function!(
+scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
+scalar_function!(round, Round);
+scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
+scalar_function!(rtrim, Rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string.");
+scalar_function!(sha224, SHA224);
+scalar_function!(sha256, SHA256);
+scalar_function!(sha384, SHA384);
+scalar_function!(sha512, SHA512);
+scalar_function!(signum, Signum);
+scalar_function!(sin, Sin);
+scalar_function!(split_part, SplitPart, "Splits string at occurrences of delimiter and returns the n'th field (counting from one).");
+scalar_function!(sqrt, Sqrt);
+scalar_function!(
+    starts_with,
+    StartsWith,
+    "Returns true if string starts with prefix."
+);
+scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
+scalar_function!(substr, Substr);
+scalar_function!(tan, Tan);
+scalar_function!(
     to_hex,
+    ToHex,
     "Converts the number to its equivalent hexadecimal representation."
 );
-define_unary_function!(translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
-define_unary_function!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
-define_unary_function!(upper, "Converts the string to all upper case.");
-define_unary_function!(avg);
-define_unary_function!(min);
-define_unary_function!(max);
-define_unary_function!(count);
-define_unary_function!(approx_distinct);
+scalar_function!(to_timestamp, ToTimestamp);
+scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
+scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
+scalar_function!(trunc, Trunc);
+scalar_function!(upper, Upper, "Converts the string to all upper case.");
+
+aggregate_function!(avg, Avg);
+aggregate_function!(count, Count);
+aggregate_function!(max, Max);
+aggregate_function!(min, Min);
+aggregate_function!(sum, Sum);

Review comment:
       Nice catch! 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to