This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e36e8ab  Add arrow cast (#962)
e36e8ab is described below

commit e36e8abfb96211003afd5406cf4d8bddaabaad71
Author: kosiew <[email protected]>
AuthorDate: Tue Jan 7 21:28:25 2025 +0800

    Add arrow cast (#962)
    
    * feat: add data_type parameter to expr_fn macro for arrow_cast function
    
    * feat: add arrow_cast function to cast expressions to specified data types
    
    * docs: add casting section to user guide with examples for arrow_cast function
    
    * test: add unit test for arrow_cast function to validate casting to Float64 and Int32
    
    * fix: update arrow_cast function to accept Expr type for data_type parameter
    
    * fix: update test_arrow_cast to use literal casting for data types
    
    * fix: update arrow_cast function to accept string type for data_type parameter
    
    * fix: update arrow_cast function to accept Expr type for data_type parameter
    
    * fix: update test_arrow_cast to use literal for data type parameters
    
    * fix: update arrow_cast function to use arg_1 for datatype parameter
    
    * fix: update arrow_cast function to accept string type for data_type parameter
    
    * Revert "fix: update arrow_cast function to accept string type for data_type parameter"
    
    This reverts commit eba0d320820e8f3f9688781f27b2a5579c0e9949.
    
    * fix: update test_arrow_cast to cast literals to string type for arrow_cast function
    
    * Revert "fix: update test_arrow_cast to cast literals to string type for arrow_cast function"
    
    This reverts commit 856ff8c4cad0075c282089b5368a7c3fd17f03d8.
    
    * fix: update arrow_cast function to accept string type for data_type parameter
    
    * Revert "fix: update arrow_cast function to accept string type for data_type parameter"
    
    This reverts commit 9e1ced7fb56c8aec47bc9f540ea5686c7246f022.
    
    * fix: add utf8_literal function to create UTF8 literal expressions in tests
    
    * Revert "fix: add utf8_literal function to create UTF8 literal expressions in tests"
    
    This reverts commit 11ed6749e02ab7b34d47fa105961f088f9fc9245.
    
    * feat: add utf8_literal function to create UTF8 literal expressions
    
    * fix: update test_arrow_cast to use column 'b'
    
    * fix: enhance utf8_literal function to handle non-string values
    
    * Add description for utf8_literal vs literal
    
    * docs: clarify utf8_literal function documentation to explain use case
    
    * docs: add clarification comments for utf8_literal usage in arrow_cast tests
    
    * docs: implement ruff recommendation
    
    * fix ruff errors
    
    * docs: update examples to use utf8_literal in arrow_cast function
    
    * docs: correct typo in comment for utf8_literal usage in test_arrow_cast
    
    * docs: remove redundant comment in test_arrow_cast for clarity
    
    * refactor: rename utf8_literal to string_literal and add alias str_lit
    
    * docs: improve docstring for string_literal function for clarity
    
    * docs: update import statement to include str_lit alias for string_literal
---
 docs/source/user-guide/common-operations/functions.rst | 13 ++++++++++++-
 python/datafusion/__init__.py                          | 13 +++++++++++++
 python/datafusion/expr.py                              | 16 ++++++++++++++++
 python/datafusion/functions.py                         |  6 ++++++
 python/tests/test_functions.py                         | 18 +++++++++++++++++-
 src/functions.rs                                       |  3 ++-
 6 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/docs/source/user-guide/common-operations/functions.rst b/docs/source/user-guide/common-operations/functions.rst
index ad71c72..12097be 100644
--- a/docs/source/user-guide/common-operations/functions.rst
+++ b/docs/source/user-guide/common-operations/functions.rst
@@ -38,7 +38,7 @@ DataFusion offers mathematical functions such as :py:func:`~datafusion.functions
 
 .. ipython:: python
 
-    from datafusion import col, literal
+    from datafusion import col, literal, string_literal, str_lit
     from datafusion import functions as f
 
     df.select(
@@ -104,6 +104,17 @@ This also includes the functions for regular expressions like :py:func:`~datafus
         f.regexp_replace(col('"Name"'), literal("saur"), literal("fleur")).alias("flowers")
     )
 
+Casting
+-------
+
+Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast`
+
+.. ipython:: python
+
+    df.select(
+        f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
+        f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
+    )
 
 Other
 -----
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index e0bc57f..7367b0d 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -107,6 +107,19 @@ def literal(value):
     return Expr.literal(value)
 
 
+def string_literal(value):
+    """Create a UTF8 literal expression.
+
+    It differs from `literal` which creates a UTF8view literal.
+    """
+    return Expr.string_literal(value)
+
+
+def str_lit(value):
+    """Alias for `string_literal`."""
+    return string_literal(value)
+
+
 def lit(value):
     """Create a literal expression."""
     return Expr.literal(value)
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index b107243..16add16 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -380,6 +380,22 @@ class Expr:
             value = pa.scalar(value)
         return Expr(expr_internal.Expr.literal(value))
 
+    @staticmethod
+    def string_literal(value: str) -> Expr:
+        """Creates a new expression representing a UTF8 literal value.
+
+        It is different from `literal` because it is pa.string() instead of
+        pa.string_view()
+
+        This is needed for cases where DataFusion is expecting a UTF8 instead of
+        UTF8View literal, like in:
+        https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
+        """
+        if isinstance(value, str):
+            value = pa.scalar(value, type=pa.string())
+            return Expr(expr_internal.Expr.literal(value))
+        return Expr.literal(value)
+
     @staticmethod
     def column(value: str) -> Expr:
         """Creates a new expression representing a column."""
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index f3ee5c0..c0097c6 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -82,6 +82,7 @@ __all__ = [
     "array_to_string",
     "array_union",
     "arrow_typeof",
+    "arrow_cast",
     "ascii",
     "asin",
     "asinh",
@@ -1109,6 +1110,11 @@ def arrow_typeof(arg: Expr) -> Expr:
     return Expr(f.arrow_typeof(arg.expr))
 
 
+def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
+    """Casts an expression to a specified data type."""
+    return Expr(f.arrow_cast(expr.expr, data_type.expr))
+
+
 def random() -> Expr:
     """Returns a random value in the range ``0.0 <= x < 1.0``."""
     return Expr(f.random())
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index 0d2fa8f..5dce188 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -23,7 +23,7 @@ from datetime import datetime
 
 from datafusion import SessionContext, column
 from datafusion import functions as f
-from datafusion import literal
+from datafusion import literal, string_literal
 
 np.seterr(invalid="ignore")
 
@@ -907,6 +907,22 @@ def test_temporal_functions(df):
     assert result.column(10) == pa.array([31, 26, 2], type=pa.float64())
 
 
+def test_arrow_cast(df):
+    df = df.select(
+        # we use `string_literal` to return utf8 instead of `literal` which returns
+        # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
+        # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
+        f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
+        f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
+    )
+    result = df.collect()
+    assert len(result) == 1
+    result = result[0]
+
+    assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
+    assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
+
+
 def test_case(df):
     df = df.select(
         f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),
diff --git a/src/functions.rs b/src/functions.rs
index 5c45028..ccc1981 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -400,7 +400,6 @@ macro_rules! expr_fn {
         }
     };
 }
-
 /// Generates a [pyo3] wrapper for [datafusion::functions::expr_fn]
 ///
 /// These functions take a single `Vec<PyExpr>` argument using `pyo3(signature = (*args))`.
@@ -575,6 +574,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
 expr_fn_vec!(named_struct);
 expr_fn!(from_unixtime, unixtime);
 expr_fn!(arrow_typeof, arg_1);
+expr_fn!(arrow_cast, arg_1 datatype);
 expr_fn!(random);
 
 // Array Functions
@@ -867,6 +867,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(range))?;
     m.add_wrapped(wrap_pyfunction!(array_agg))?;
     m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
+    m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
     m.add_wrapped(wrap_pyfunction!(ascii))?;
     m.add_wrapped(wrap_pyfunction!(asin))?;
     m.add_wrapped(wrap_pyfunction!(asinh))?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to