This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new e36e8ab Add arrow cast (#962)
e36e8ab is described below
commit e36e8abfb96211003afd5406cf4d8bddaabaad71
Author: kosiew <[email protected]>
AuthorDate: Tue Jan 7 21:28:25 2025 +0800
Add arrow cast (#962)
* feat: add data_type parameter to expr_fn macro for arrow_cast function
* feat: add arrow_cast function to cast expressions to specified data types
* docs: add casting section to user guide with examples for arrow_cast
function
* test: add unit test for arrow_cast function to validate casting to
Float64 and Int32
* fix: update arrow_cast function to accept Expr type for data_type
parameter
* fix: update test_arrow_cast to use literal casting for data types
* fix: update arrow_cast function to accept string type for data_type
parameter
* fix: update arrow_cast function to accept Expr type for data_type
parameter
* fix: update test_arrow_cast to use literal for data type parameters
* fix: update arrow_cast function to use arg_1 for datatype parameter
* fix: update arrow_cast function to accept string type for data_type
parameter
* Revert "fix: update arrow_cast function to accept string type for
data_type parameter"
This reverts commit eba0d320820e8f3f9688781f27b2a5579c0e9949.
* fix: update test_arrow_cast to cast literals to string type for
arrow_cast function
* Revert "fix: update test_arrow_cast to cast literals to string type for
arrow_cast function"
This reverts commit 856ff8c4cad0075c282089b5368a7c3fd17f03d8.
* fix: update arrow_cast function to accept string type for data_type
parameter
* Revert "fix: update arrow_cast function to accept string type for
data_type parameter"
This reverts commit 9e1ced7fb56c8aec47bc9f540ea5686c7246f022.
* fix: add utf8_literal function to create UTF8 literal expressions in tests
* Revert "fix: add utf8_literal function to create UTF8 literal expressions
in tests"
This reverts commit 11ed6749e02ab7b34d47fa105961f088f9fc9245.
* feat: add utf8_literal function to create UTF8 literal expressions
* fix: update test_arrow_cast to use column 'b'
* fix: enhance utf8_literal function to handle non-string values
* Add description for utf8_literal vs literal
* docs: clarify utf8_literal function documentation to explain use case
* docs: add clarification comments for utf8_literal usage in arrow_cast
tests
* docs: implement ruff recommendation
* fix ruff errors
* docs: update examples to use utf8_literal in arrow_cast function
* docs: correct typo in comment for utf8_literal usage in test_arrow_cast
* docs: remove redundant comment in test_arrow_cast for clarity
* refactor: rename utf8_literal to string_literal and add alias str_lit
* docs: improve docstring for string_literal function for clarity
* docs: update import statement to include str_lit alias for string_literal
---
docs/source/user-guide/common-operations/functions.rst | 13 ++++++++++++-
python/datafusion/__init__.py | 13 +++++++++++++
python/datafusion/expr.py | 16 ++++++++++++++++
python/datafusion/functions.py | 6 ++++++
python/tests/test_functions.py | 18 +++++++++++++++++-
src/functions.rs | 3 ++-
6 files changed, 66 insertions(+), 3 deletions(-)
diff --git a/docs/source/user-guide/common-operations/functions.rst b/docs/source/user-guide/common-operations/functions.rst
index ad71c72..12097be 100644
--- a/docs/source/user-guide/common-operations/functions.rst
+++ b/docs/source/user-guide/common-operations/functions.rst
@@ -38,7 +38,7 @@ DataFusion offers mathematical functions such as :py:func:`~datafusion.functions
.. ipython:: python
- from datafusion import col, literal
+ from datafusion import col, literal, string_literal, str_lit
from datafusion import functions as f
df.select(
@@ -104,6 +104,17 @@ This also includes the functions for regular expressions like :py:func:`~datafus
f.regexp_replace(col('"Name"'), literal("saur"), literal("fleur")).alias("flowers")
)
+Casting
+-------
+
+Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast`
+
+.. ipython:: python
+
+ df.select(
+ f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
+ f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
+ )
Other
-----
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index e0bc57f..7367b0d 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -107,6 +107,19 @@ def literal(value):
return Expr.literal(value)
+def string_literal(value):
+ """Create a UTF8 literal expression.
+
+ It differs from `literal` which creates a UTF8view literal.
+ """
+ return Expr.string_literal(value)
+
+
+def str_lit(value):
+ """Alias for `string_literal`."""
+ return string_literal(value)
+
+
def lit(value):
"""Create a literal expression."""
return Expr.literal(value)
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index b107243..16add16 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -380,6 +380,22 @@ class Expr:
value = pa.scalar(value)
return Expr(expr_internal.Expr.literal(value))
+ @staticmethod
+ def string_literal(value: str) -> Expr:
+ """Creates a new expression representing a UTF8 literal value.
+
+ It is different from `literal` because it is pa.string() instead of
+ pa.string_view()
+
+ This is needed for cases where DataFusion is expecting a UTF8 instead of
+ UTF8View literal, like in:
+
https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
+ """
+ if isinstance(value, str):
+ value = pa.scalar(value, type=pa.string())
+ return Expr(expr_internal.Expr.literal(value))
+ return Expr.literal(value)
+
@staticmethod
def column(value: str) -> Expr:
"""Creates a new expression representing a column."""
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index f3ee5c0..c0097c6 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -82,6 +82,7 @@ __all__ = [
"array_to_string",
"array_union",
"arrow_typeof",
+ "arrow_cast",
"ascii",
"asin",
"asinh",
@@ -1109,6 +1110,11 @@ def arrow_typeof(arg: Expr) -> Expr:
return Expr(f.arrow_typeof(arg.expr))
+def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
+ """Casts an expression to a specified data type."""
+ return Expr(f.arrow_cast(expr.expr, data_type.expr))
+
+
def random() -> Expr:
"""Returns a random value in the range ``0.0 <= x < 1.0``."""
return Expr(f.random())
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index 0d2fa8f..5dce188 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -23,7 +23,7 @@ from datetime import datetime
from datafusion import SessionContext, column
from datafusion import functions as f
-from datafusion import literal
+from datafusion import literal, string_literal
np.seterr(invalid="ignore")
@@ -907,6 +907,22 @@ def test_temporal_functions(df):
assert result.column(10) == pa.array([31, 26, 2], type=pa.float64())
+def test_arrow_cast(df):
+ df = df.select(
+ # we use `string_literal` to return utf8 instead of `literal` which returns
+ # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
+ # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
+ f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
+ f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
+ )
+ result = df.collect()
+ assert len(result) == 1
+ result = result[0]
+
+ assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
+ assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
+
+
def test_case(df):
df = df.select(
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),
diff --git a/src/functions.rs b/src/functions.rs
index 5c45028..ccc1981 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -400,7 +400,6 @@ macro_rules! expr_fn {
}
};
}
-
/// Generates a [pyo3] wrapper for [datafusion::functions::expr_fn]
///
/// These functions take a single `Vec<PyExpr>` argument using `pyo3(signature = (*args))`.
@@ -575,6 +574,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
expr_fn_vec!(named_struct);
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
+expr_fn!(arrow_cast, arg_1 datatype);
expr_fn!(random);
// Array Functions
@@ -867,6 +867,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(range))?;
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
+ m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]