This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 0905f5f  feat: add fill_null/nan (#919)
0905f5f is described below

commit 0905f5fca4b763fc61e5e2093a85ad05e203d7fb
Author: Ion Koutsouris <[email protected]>
AuthorDate: Wed Oct 16 14:36:56 2024 +0200

    feat: add fill_null/nan (#919)
---
 python/datafusion/expr.py      | 12 ++++++++++++
 python/datafusion/functions.py |  6 ++++++
 python/tests/test_expr.py      | 33 ++++++++++++++++++++++++++++++---
 src/functions.rs               |  6 ++++++
 4 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 8600627..c4e7713 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -406,6 +406,18 @@ class Expr:
         """Returns ``True`` if this expression is not null."""
         return Expr(self.expr.is_not_null())
 
+    def fill_nan(self, value: Any | Expr | None = None) -> Expr:
+        """Fill NaN values with a provided value."""
+        if not isinstance(value, Expr):
+            value = Expr.literal(value)
+        return Expr(functions_internal.nanvl(self.expr, value.expr))
+
+    def fill_null(self, value: Any | Expr | None = None) -> Expr:
+        """Fill NULL values with a provided value."""
+        if not isinstance(value, Expr):
+            value = Expr.literal(value)
+        return Expr(functions_internal.nvl(self.expr, value.expr))
+
     _to_pyarrow_types = {
         float: pa.float64(),
         int: pa.int64(),
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 0401afb..7273219 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -186,6 +186,7 @@ __all__ = [
     "min",
     "named_struct",
     "nanvl",
+    "nvl",
     "now",
     "nth_value",
     "nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
     return Expr(f.nanvl(x.expr, y.expr))
 
 
+def nvl(x: Expr, y: Expr) -> Expr:
+    """Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
+    return Expr(f.nvl(x.expr, y.expr))
+
+
 def octet_length(arg: Expr) -> Expr:
     """Returns the number of bytes of a string."""
     return Expr(f.octet_length(arg.expr))
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index b58177f..1847ede 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pyarrow
+import pyarrow as pa
 import pytest
 from datafusion import SessionContext, col
 from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
 def test_relational_expr(test_ctx):
     ctx = SessionContext()
 
-    batch = pyarrow.RecordBatch.from_arrays(
-        [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
         names=["a", "b"],
     )
     df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
     # returns appropriate result
     assert name == expr.schema_name()
     assert name == "foo"
+
+
+@pytest.fixture
+def df():
+    ctx = SessionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
+        names=["a", "b", "c"],
+    )
+
+    return ctx.from_arrow(batch)
+
+
+def test_fill_null(df):
+    df = df.select(
+        col("a").fill_null(100).alias("a"),
+        col("b").fill_null(25).alias("b"),
+        col("c").fill_null(1234).alias("c"),
+    )
+    df.show()
+    result = df.collect()[0]
+
+    assert result.column(0) == pa.array([1, 2, 100])
+    assert result.column(1) == pa.array([4, 25, 6])
+    assert result.column(2) == pa.array([1234, 1234, 8])
diff --git a/src/functions.rs b/src/functions.rs
index 6f8dd7a..24d33af 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -490,6 +490,11 @@ expr_fn!(
     x y,
     "Returns x if x is not NaN otherwise returns y."
 );
+expr_fn!(
+    nvl,
+    x y,
+    "Returns x if x is not NULL otherwise returns y."
+);
 expr_fn!(nullif, arg_1 arg_2);
 expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
 expr_fn_vec!(overlay);
@@ -913,6 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(min))?;
     m.add_wrapped(wrap_pyfunction!(named_struct))?;
     m.add_wrapped(wrap_pyfunction!(nanvl))?;
+    m.add_wrapped(wrap_pyfunction!(nvl))?;
     m.add_wrapped(wrap_pyfunction!(now))?;
     m.add_wrapped(wrap_pyfunction!(nullif))?;
     m.add_wrapped(wrap_pyfunction!(octet_length))?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to