This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 0905f5f feat: add fill_null/nan (#919)
0905f5f is described below
commit 0905f5fca4b763fc61e5e2093a85ad05e203d7fb
Author: Ion Koutsouris <[email protected]>
AuthorDate: Wed Oct 16 14:36:56 2024 +0200
feat: add fill_null/nan (#919)
---
python/datafusion/expr.py | 12 ++++++++++++
python/datafusion/functions.py | 6 ++++++
python/tests/test_expr.py | 33 ++++++++++++++++++++++++++++++---
src/functions.rs | 6 ++++++
4 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 8600627..c4e7713 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -406,6 +406,18 @@ class Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())
+ def fill_nan(self, value: Any | Expr | None = None) -> Expr:
+ """Fill NaN values with a provided value."""
+ if not isinstance(value, Expr):
+ value = Expr.literal(value)
+ return Expr(functions_internal.nanvl(self.expr, value.expr))
+
+ def fill_null(self, value: Any | Expr | None = None) -> Expr:
+ """Fill NULL values with a provided value."""
+ if not isinstance(value, Expr):
+ value = Expr.literal(value)
+ return Expr(functions_internal.nvl(self.expr, value.expr))
+
_to_pyarrow_types = {
float: pa.float64(),
int: pa.int64(),
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 0401afb..7273219 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -186,6 +186,7 @@ __all__ = [
"min",
"named_struct",
"nanvl",
+ "nvl",
"now",
"nth_value",
"nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
return Expr(f.nanvl(x.expr, y.expr))
+def nvl(x: Expr, y: Expr) -> Expr:
+ """Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
+ return Expr(f.nvl(x.expr, y.expr))
+
+
def octet_length(arg: Expr) -> Expr:
"""Returns the number of bytes of a string."""
return Expr(f.octet_length(arg.expr))
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index b58177f..1847ede 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import pyarrow
+import pyarrow as pa
import pytest
from datafusion import SessionContext, col
from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
def test_relational_expr(test_ctx):
ctx = SessionContext()
- batch = pyarrow.RecordBatch.from_arrays(
- [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
# returns appropriate result
assert name == expr.schema_name()
assert name == "foo"
+
+
+@pytest.fixture
+def df():
+ ctx = SessionContext()
+
+ # create a RecordBatch and a new DataFrame from it
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
+ names=["a", "b", "c"],
+ )
+
+ return ctx.from_arrow(batch)
+
+
+def test_fill_null(df):
+ df = df.select(
+ col("a").fill_null(100).alias("a"),
+ col("b").fill_null(25).alias("b"),
+ col("c").fill_null(1234).alias("c"),
+ )
+ df.show()
+ result = df.collect()[0]
+
+ assert result.column(0) == pa.array([1, 2, 100])
+ assert result.column(1) == pa.array([4, 25, 6])
+ assert result.column(2) == pa.array([1234, 1234, 8])
diff --git a/src/functions.rs b/src/functions.rs
index 6f8dd7a..24d33af 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -490,6 +490,11 @@ expr_fn!(
x y,
"Returns x if x is not NaN otherwise returns y."
);
+expr_fn!(
+ nvl,
+ x y,
+ "Returns x if x is not NULL otherwise returns y."
+);
expr_fn!(nullif, arg_1 arg_2);
expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
expr_fn_vec!(overlay);
@@ -913,6 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(min))?;
m.add_wrapped(wrap_pyfunction!(named_struct))?;
m.add_wrapped(wrap_pyfunction!(nanvl))?;
+ m.add_wrapped(wrap_pyfunction!(nvl))?;
m.add_wrapped(wrap_pyfunction!(now))?;
m.add_wrapped(wrap_pyfunction!(nullif))?;
m.add_wrapped(wrap_pyfunction!(octet_length))?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]