kszucs commented on a change in pull request #873:
URL: https://github.com/apache/arrow-datafusion/pull/873#discussion_r738589948
##########
File path: python/src/functions.rs
##########
@@ -15,232 +15,210 @@
// specific language governing permissions and limitations
// under the License.
-use crate::udaf;
-use crate::udf;
-use crate::{expression, types::PyDataType};
-use datafusion::arrow::datatypes::DataType;
-use datafusion::logical_plan::{self, Literal};
-use datafusion::physical_plan::functions::Volatility;
-use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python};
use std::sync::Arc;
-/// Expression representing a column on the existing plan.
-#[pyfunction]
-#[pyo3(text_signature = "(name)")]
-fn col(name: &str) -> expression::Expression {
- expression::Expression {
- expr: logical_plan::col(name),
- }
-}
-
-/// # A bridge type that converts PyAny data into datafusion literal
-///
-/// Note that the ordering here matters because it has to be from
-/// narrow to wider values because Python has duck typing so putting
-/// Int before Boolean results in a premature match.
-#[derive(FromPyObject)]
-enum PythonLiteral<'a> {
- Boolean(bool),
- Int(i64),
- UInt(u64),
- Float(f64),
- Str(&'a str),
- Binary(&'a [u8]),
-}
+use pyo3::{prelude::*, wrap_pyfunction, Python};
-impl<'a> Literal for PythonLiteral<'a> {
- fn lit(&self) -> logical_plan::Expr {
- match self {
- PythonLiteral::Boolean(val) => val.lit(),
- PythonLiteral::Int(val) => val.lit(),
- PythonLiteral::UInt(val) => val.lit(),
- PythonLiteral::Float(val) => val.lit(),
- PythonLiteral::Str(val) => val.lit(),
- PythonLiteral::Binary(val) => val.lit(),
- }
- }
-}
+use datafusion::arrow::datatypes::DataType;
+use datafusion::logical_plan;
+//use datafusion::logical_plan::Expr;
+use datafusion::physical_plan::functions::Volatility;
+use datafusion::physical_plan::{
+ aggregates::AggregateFunction, functions::BuiltinScalarFunction,
+};
-/// Expression representing a constant value
-#[pyfunction]
-#[pyo3(text_signature = "(value)")]
-fn lit(value: &PyAny) -> PyResult<expression::Expression> {
- let py_lit = value.extract::<PythonLiteral>()?;
- let expr = py_lit.lit();
- Ok(expression::Expression { expr })
-}
+use crate::{
+ expression::{PyAggregateUDF, PyExpr, PyScalarUDF},
+ udaf, udf,
+};
#[pyfunction]
-fn array(value: Vec<expression::Expression>) -> expression::Expression {
- expression::Expression {
+fn array(value: Vec<PyExpr>) -> PyExpr {
+ PyExpr {
expr: logical_plan::array(value.into_iter().map(|x|
x.expr).collect::<Vec<_>>()),
}
}
#[pyfunction]
-fn in_list(
- expr: expression::Expression,
- value: Vec<expression::Expression>,
- negated: bool,
-) -> expression::Expression {
- expression::Expression {
- expr: logical_plan::in_list(
- expr.expr,
- value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
- negated,
- ),
- }
+fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
+ logical_plan::in_list(
+ expr.expr,
+ value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
+ negated,
+ )
+ .into()
}
/// Current date and time
#[pyfunction]
-fn now() -> expression::Expression {
- expression::Expression {
+fn now() -> PyExpr {
+ PyExpr {
// here lit(0) is a stub for conform to arity
expr: logical_plan::now(logical_plan::lit(0)),
}
}
/// Returns a random value in the range 0.0 <= x < 1.0
#[pyfunction]
-fn random() -> expression::Expression {
- expression::Expression {
+fn random() -> PyExpr {
+ PyExpr {
expr: logical_plan::random(),
}
}
/// Computes a binary hash of the given data. type is the algorithm to use.
/// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
#[pyfunction(value, method)]
-fn digest(
- value: expression::Expression,
- method: expression::Expression,
-) -> expression::Expression {
- expression::Expression {
+fn digest(value: PyExpr, method: PyExpr) -> PyExpr {
+ PyExpr {
expr: logical_plan::digest(value.expr, method.expr),
}
}
/// Concatenates the text representations of all the arguments.
/// NULL arguments are ignored.
#[pyfunction(args = "*")]
-fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
- let expressions = expression::from_tuple(args)?;
- let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
- Ok(expression::Expression {
- expr: logical_plan::concat(&args),
- })
+fn concat(args: Vec<PyExpr>) -> PyResult<PyExpr> {
+ let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
+ Ok(logical_plan::concat(&args).into())
}
/// Concatenates all but the first argument, with separators.
/// The first argument is used as the separator string, and should not be NULL.
/// Other NULL arguments are ignored.
#[pyfunction(sep, args = "*")]
-fn concat_ws(sep: String, args: &PyTuple) -> PyResult<expression::Expression> {
- let expressions = expression::from_tuple(args)?;
- let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
- Ok(expression::Expression {
- expr: logical_plan::concat_ws(sep, &args),
- })
+fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
+ let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
+ Ok(logical_plan::concat_ws(sep, &args).into())
}
-macro_rules! define_unary_function {
- ($NAME: ident) => {
- #[doc = "This function is not documented yet"]
- #[pyfunction]
- fn $NAME(value: expression::Expression) -> expression::Expression {
- expression::Expression {
- expr: logical_plan::$NAME(value.expr),
- }
- }
+macro_rules! scalar_function {
+ ($NAME: ident, $FUNC: ident) => {
+ scalar_function!($NAME, $FUNC, stringify!($NAME));
};
- ($NAME: ident, $DOC: expr) => {
+ ($NAME: ident, $FUNC: ident, $DOC: expr) => {
#[doc = $DOC]
- #[pyfunction]
- fn $NAME(value: expression::Expression) -> expression::Expression {
- expression::Expression {
- expr: logical_plan::$NAME(value.expr),
- }
+ #[pyfunction(args = "*")]
+ fn $NAME(args: Vec<PyExpr>) -> PyExpr {
+ let expr = logical_plan::Expr::ScalarFunction {
+ fun: BuiltinScalarFunction::$FUNC,
+ args: args.into_iter().map(|e| e.into()).collect(),
+ };
+ expr.into()
}
};
}
-define_unary_function!(sqrt, "sqrt");
-define_unary_function!(sin, "sin");
-define_unary_function!(cos, "cos");
-define_unary_function!(tan, "tan");
-define_unary_function!(asin, "asin");
-define_unary_function!(acos, "acos");
-define_unary_function!(atan, "atan");
-define_unary_function!(floor, "floor");
-define_unary_function!(ceil, "ceil");
-define_unary_function!(round, "round");
-define_unary_function!(trunc, "trunc");
-define_unary_function!(abs, "abs");
-define_unary_function!(signum, "signum");
-define_unary_function!(exp, "exp");
-define_unary_function!(ln, "ln");
-define_unary_function!(log2, "log2");
-define_unary_function!(log10, "log10");
+macro_rules! aggregate_function {
+ ($NAME: ident, $FUNC: ident) => {
+ aggregate_function!($NAME, $FUNC, stringify!($NAME));
+ };
+ ($NAME: ident, $FUNC: ident, $DOC: expr) => {
+ #[doc = $DOC]
+ #[pyfunction(args = "*", distinct = "false")]
+ fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
+ let expr = logical_plan::Expr::AggregateFunction {
+ fun: AggregateFunction::$FUNC,
+ args: args.into_iter().map(|e| e.into()).collect(),
+ distinct,
+ };
+ expr.into()
+ }
+ };
+}
-define_unary_function!(ascii, "Returns the numeric code of the first character
of the argument. In UTF8 encoding, returns the Unicode code point of the
character. In other multibyte encodings, the argument must be an ASCII
character.");
-define_unary_function!(sum);
-define_unary_function!(
+scalar_function!(abs, Abs);
+scalar_function!(acos, Acos);
+scalar_function!(ascii, Ascii, "Returns the numeric code of the first
character of the argument. In UTF8 encoding, returns the Unicode code point of
the character. In other multibyte encodings, the argument must be an ASCII
character.");
+scalar_function!(asin, Asin);
+scalar_function!(atan, Atan);
+scalar_function!(
bit_length,
+ BitLength,
"Returns number of bits in the string (8 times the octet_length)."
);
-define_unary_function!(btrim, "Removes the longest string containing only
characters in characters (a space by default) from the start and end of
string.");
-define_unary_function!(
+scalar_function!(btrim, Btrim, "Removes the longest string containing only
characters in characters (a space by default) from the start and end of
string.");
+scalar_function!(ceil, Ceil);
+scalar_function!(
character_length,
+ CharacterLength,
"Returns number of characters in the string."
);
-define_unary_function!(chr, "Returns the character with the given code.");
-define_unary_function!(initcap, "Converts the first letter of each word to
upper case and the rest to lower case. Words are sequences of alphanumeric
characters separated by non-alphanumeric characters.");
-define_unary_function!(left, "Returns first n characters in the string, or
when n is negative, returns all but last |n| characters.");
-define_unary_function!(lower, "Converts the string to all lower case");
-define_unary_function!(lpad, "Extends the string to length length by
prepending the characters fill (a space by default). If the string is already
longer than length then it is truncated (on the right).");
-define_unary_function!(ltrim, "Removes the longest string containing only
characters in characters (a space by default) from the start of string.");
-define_unary_function!(
+scalar_function!(chr, Chr, "Returns the character with the given code.");
+scalar_function!(cos, Cos);
+scalar_function!(exp, Exp);
+scalar_function!(floor, Floor);
+scalar_function!(initcap, InitCap, "Converts the first letter of each word to
upper case and the rest to lower case. Words are sequences of alphanumeric
characters separated by non-alphanumeric characters.");
+scalar_function!(left, Left, "Returns first n characters in the string, or
when n is negative, returns all but last |n| characters.");
+scalar_function!(ln, Ln);
+scalar_function!(log10, Log10);
+scalar_function!(log2, Log2);
+scalar_function!(lower, Lower, "Converts the string to all lower case");
+scalar_function!(lpad, Lpad, "Extends the string to length length by
prepending the characters fill (a space by default). If the string is already
longer than length then it is truncated (on the right).");
+scalar_function!(ltrim, Ltrim, "Removes the longest string containing only
characters in characters (a space by default) from the start of string.");
+scalar_function!(
md5,
+ MD5,
"Computes the MD5 hash of the argument, with the result written in
hexadecimal."
);
-define_unary_function!(octet_length, "Returns number of bytes in the string.
Since this version of the function accepts type character directly, it will not
strip trailing spaces.");
-define_unary_function!(
- replace,
- "Replaces all occurrences in string of substring from with substring to."
-);
-define_unary_function!(repeat, "Repeats string the specified number of
times.");
-define_unary_function!(
+scalar_function!(octet_length, OctetLength, "Returns number of bytes in the
string. Since this version of the function accepts type character directly, it
will not strip trailing spaces.");
+scalar_function!(regexp_match, RegexpMatch);
+scalar_function!(
regexp_replace,
+ RegexpReplace,
"Replaces substring(s) matching a POSIX regular expression"
);
-define_unary_function!(
+scalar_function!(
+ repeat,
+ Repeat,
+ "Repeats string the specified number of times."
+);
+scalar_function!(
+ replace,
+ Replace,
+ "Replaces all occurrences in string of substring from with substring to."
+);
+scalar_function!(
reverse,
+ Reverse,
"Reverses the order of the characters in the string."
);
-define_unary_function!(right, "Returns last n characters in the string, or
when n is negative, returns all but first |n| characters.");
-define_unary_function!(rpad, "Extends the string to length length by appending
the characters fill (a space by default). If the string is already longer than
length then it is truncated.");
-define_unary_function!(rtrim, "Removes the longest string containing only
characters in characters (a space by default) from the end of string.");
-define_unary_function!(sha224);
-define_unary_function!(sha256);
-define_unary_function!(sha384);
-define_unary_function!(sha512);
-define_unary_function!(split_part, "Splits string at occurrences of delimiter
and returns the n'th field (counting from one).");
-define_unary_function!(starts_with, "Returns true if string starts with
prefix.");
-define_unary_function!(strpos,"Returns starting index of specified substring
within string, or zero if it's not present. (Same as position(substring in
string), but note the reversed argument order.)");
-define_unary_function!(substr);
-define_unary_function!(
+scalar_function!(right, Right, "Returns last n characters in the string, or
when n is negative, returns all but first |n| characters.");
+scalar_function!(round, Round);
+scalar_function!(rpad, Rpad, "Extends the string to length length by appending
the characters fill (a space by default). If the string is already longer than
length then it is truncated.");
+scalar_function!(rtrim, Rtrim, "Removes the longest string containing only
characters in characters (a space by default) from the end of string.");
+scalar_function!(sha224, SHA224);
+scalar_function!(sha256, SHA256);
+scalar_function!(sha384, SHA384);
+scalar_function!(sha512, SHA512);
+scalar_function!(signum, Signum);
+scalar_function!(sin, Sin);
+scalar_function!(split_part, SplitPart, "Splits string at occurrences of
delimiter and returns the n'th field (counting from one).");
+scalar_function!(sqrt, Sqrt);
+scalar_function!(
+ starts_with,
+ StartsWith,
+ "Returns true if string starts with prefix."
+);
+scalar_function!(strpos, Strpos, "Returns starting index of specified
substring within string, or zero if it's not present. (Same as
position(substring in string), but note the reversed argument order.)");
+scalar_function!(substr, Substr);
+scalar_function!(tan, Tan);
+scalar_function!(
to_hex,
+ ToHex,
"Converts the number to its equivalent hexadecimal representation."
);
-define_unary_function!(translate, "Replaces each character in string that
matches a character in the from set with the corresponding character in the to
set. If from is longer than to, occurrences of the extra characters in from are
deleted.");
-define_unary_function!(trim, "Removes the longest string containing only
characters in characters (a space by default) from the start, end, or both ends
(BOTH is the default) of string.");
-define_unary_function!(upper, "Converts the string to all upper case.");
-define_unary_function!(avg);
-define_unary_function!(min);
-define_unary_function!(max);
-define_unary_function!(count);
-define_unary_function!(approx_distinct);
+scalar_function!(to_timestamp, ToTimestamp);
+scalar_function!(translate, Translate, "Replaces each character in string that
matches a character in the from set with the corresponding character in the to
set. If from is longer than to, occurrences of the extra characters in from are
deleted.");
+scalar_function!(trim, Trim, "Removes the longest string containing only
characters in characters (a space by default) from the start, end, or both ends
(BOTH is the default) of string.");
+scalar_function!(trunc, Trunc);
+scalar_function!(upper, Upper, "Converts the string to all upper case.");
+
+aggregate_function!(avg, Avg);
+aggregate_function!(count, Count);
+aggregate_function!(max, Max);
+aggregate_function!(min, Min);
+aggregate_function!(sum, Sum);
Review comment:
Nice catch!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]