This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e015482  feat: add `cardinality` function to calculate total elements 
in an array (#937)
e015482 is described below

commit e015482750e9e08bd426bfcf649445d53705c51a
Author: kosiew <[email protected]>
AuthorDate: Tue Oct 29 18:16:50 2024 +0800

    feat: add `cardinality` function to calculate total elements in an array 
(#937)
---
 .../user-guide/common-operations/expressions.rst       | 14 ++++++++++++++
 python/datafusion/functions.py                         |  6 ++++++
 python/tests/test_functions.py                         | 18 ++++++++++++++++++
 src/functions.rs                                       |  2 ++
 4 files changed, 40 insertions(+)

diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst
index 77f3359..23430d3 100644
--- a/docs/source/user-guide/common-operations/expressions.rst
+++ b/docs/source/user-guide/common-operations/expressions.rst
@@ -96,6 +96,20 @@ This function returns a boolean indicating whether the array is empty.
 
 In this example, the `is_empty` column will contain `True` for the first row and `False` for the second row.
 
+To get the total number of elements in an array, you can use the function :py:func:`datafusion.functions.cardinality`.
+This function returns an integer indicating the total number of elements in the array.
+
+.. ipython:: python
+
+    from datafusion import SessionContext, col
+    from datafusion.functions import cardinality
+
+    ctx = SessionContext()
+    df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]})
+    df.select(cardinality(col("a")).alias("num_elements"))
+
+In this example, the `num_elements` column will contain `3` for both rows.
+
 Structs
 -------
 
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 570a6ce..e67ba4a 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -132,6 +132,7 @@ __all__ = [
     "find_in_set",
     "first_value",
     "flatten",
+    "cardinality",
     "floor",
     "from_unixtime",
     "gcd",
@@ -1516,6 +1517,11 @@ def flatten(array: Expr) -> Expr:
     return Expr(f.flatten(array.expr))
 
 
+def cardinality(array: Expr) -> Expr:
+    """Returns the total number of elements in the array."""
+    return Expr(f.cardinality(array.expr))
+
+
 # aggregate functions
 def approx_distinct(
     expression: Expr,
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index e6fd41d..37943e5 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -540,6 +540,24 @@ def test_array_function_flatten():
         )
 
 
+def test_array_function_cardinality():
+    data = [[1, 2, 3], [4, 4, 5, 6]]
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
+    df = ctx.create_dataframe([[batch]])
+
+    stmt = f.cardinality(column("arr"))
+    py_expr = [len(arr) for arr in data]  # Expected lengths: [3, 4]
+    # assert py_expr lengths
+
+    query_result = df.select(stmt).collect()[0].column(0)
+
+    for a, b in zip(query_result, py_expr):
+        np.testing.assert_array_equal(
+            np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
+        )
+
+
 @pytest.mark.parametrize(
     ("stmt", "py_expr"),
     [
diff --git a/src/functions.rs b/src/functions.rs
index 4facb6c..fe3531b 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -594,6 +594,7 @@ array_fn!(array_intersect, first_array second_array);
 array_fn!(array_union, array1 array2);
 array_fn!(array_except, first_array second_array);
 array_fn!(array_resize, array size value);
+array_fn!(cardinality, array);
 array_fn!(flatten, array);
 array_fn!(range, start stop step);
 
@@ -1030,6 +1031,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(array_sort))?;
     m.add_wrapped(wrap_pyfunction!(array_slice))?;
     m.add_wrapped(wrap_pyfunction!(flatten))?;
+    m.add_wrapped(wrap_pyfunction!(cardinality))?;
 
     // Window Functions
     m.add_wrapped(wrap_pyfunction!(lead))?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to