This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 70c099a  feat: add `cast` to DataFrame (#916)
70c099a is described below

commit 70c099aad8ec337ef88e27c125a8eeba328d62de
Author: Ion Koutsouris <[email protected]>
AuthorDate: Mon Oct 21 19:21:25 2024 +0200

    feat: add `cast` to DataFrame (#916)
    
    * feat: add with_columns
    
    * feat: add top level cast
    
    * chore: improve docstring
    
    ---------
    
    Co-authored-by: Tim Saucer <[email protected]>
---
 python/datafusion/dataframe.py | 13 +++++++++++++
 python/tests/test_dataframe.py |  9 +++++++++
 2 files changed, 22 insertions(+)

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 9c0953c..3ed6d40 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -21,6 +21,7 @@ See :ref:`user_guide_concepts` in the online documentation for more information.
 
 from __future__ import annotations
 
+
 from typing import Any, Iterable, List, Literal, TYPE_CHECKING
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
@@ -267,6 +268,18 @@ class DataFrame:
         exprs_raw = [sort_or_default(expr) for expr in exprs]
         return DataFrame(self.df.sort(*exprs_raw))
 
+    def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
+        """Cast one or more columns to a different data type.
+
+        Args:
+            mapping: Mapped with column as key and column dtype as value.
+
+        Returns:
+            DataFrame after casting columns
+        """
+        exprs = [Expr.column(col).cast(dtype) for col, dtype in mapping.items()]
+        return self.with_columns(exprs)
+
     def limit(self, count: int, offset: int = 0) -> DataFrame:
         """Return a new :py:class:`DataFrame` with a limited number of rows.
 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 0d4a7dc..bb408c9 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -247,6 +247,15 @@ def test_with_columns(df):
     assert result.column(6) == pa.array([5, 7, 9])
 
 
+def test_cast(df):
+    df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
+    expected = pa.schema(
+        [("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
+    )
+
+    assert df.schema() == expected
+
+
 def test_with_column_renamed(df):
     df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to