This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 697ca2c  Add array functions (#560)
697ca2c is described below

commit 697ca2c4d8a5b02cddcf5108edcae26cd76fe823
Author: Chih Wang <[email protected]>
AuthorDate: Mon Feb 12 23:26:49 2024 +0800

    Add array functions (#560)
    
    * Add array_has, array_has_all and array_has_any
    
    * Add array_position, array_indexof, list_position and list_indexof
    
    * Add array_to_string, array_join, list_to_string and list_join
    
    * Add array_ndims and list_ndims
    
    * Add array_push_back, list_append and list_push_back
    
    * Add array_prepend, array_push_front, list_prepend and list_push_front
    
    * Add array_pop_back and array_pop_front
    
    * Add array_positions and list_positions
    
    * Add array_remove, list_remove, array_remove_n, list_remove_n, array_remove_all and list_remove_all
    
    * Add array_repeat
    
    * Add array_replace, list_replace, array_replace_n, list_replace_n, array_replace_all, list_replace_all
    
    * Add array_slice and list_slice
---
 datafusion/tests/test_functions.py | 211 ++++++++++++++++++++++++++++++++++++-
 src/functions.rs                   |  78 ++++++++++++++
 2 files changed, 286 insertions(+), 3 deletions(-)

diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py
index d0514f8..7e77258 100644
--- a/datafusion/tests/test_functions.py
+++ b/datafusion/tests/test_functions.py
@@ -200,19 +200,62 @@ def test_math_functions():
 
 
 def test_array_functions():
-    data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
+    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
     ctx = SessionContext()
     batch = pa.RecordBatch.from_arrays(
         [np.array(data, dtype=object)], names=["arr"]
     )
     df = ctx.create_dataframe([[batch]])
 
+    def py_indexof(arr, v):
+        try:
+            return arr.index(v) + 1
+        except ValueError:
+            return np.nan
+
+    def py_arr_remove(arr, v, n=None):
+        new_arr = arr[:]
+        found = 0
+        while found != n:
+            try:
+                new_arr.remove(v)
+                found += 1
+            except ValueError:
+                break
+
+        return new_arr
+
+    def py_arr_replace(arr, from_, to, n=None):
+        new_arr = arr[:]
+        found = 0
+        while found != n:
+            try:
+                idx = new_arr.index(from_)
+                new_arr[idx] = to
+                found += 1
+            except ValueError:
+                break
+
+        return new_arr
+
     col = column("arr")
     test_items = [
         [
             f.array_append(col, literal(99.0)),
             lambda: [np.append(arr, 99.0) for arr in data],
         ],
+        [
+            f.array_push_back(col, literal(99.0)),
+            lambda: [np.append(arr, 99.0) for arr in data],
+        ],
+        [
+            f.list_append(col, literal(99.0)),
+            lambda: [np.append(arr, 99.0) for arr in data],
+        ],
+        [
+            f.list_push_back(col, literal(99.0)),
+            lambda: [np.append(arr, 99.0) for arr in data],
+        ],
         [
             f.array_concat(col, col),
             lambda: [np.concatenate([arr, arr]) for arr in data],
@@ -253,12 +296,174 @@ def test_array_functions():
             f.list_length(col),
             lambda: [len(r) for r in data],
         ],
+        [
+            f.array_has(col, literal(1.0)),
+            lambda: [1.0 in r for r in data],
+        ],
+        [
+            f.array_has_all(
+                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+            ),
+            lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+        ],
+        [
+            f.array_has_any(
+                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+            ),
+            lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+        ],
+        [
+            f.array_position(col, literal(1.0)),
+            lambda: [py_indexof(r, 1.0) for r in data],
+        ],
+        [
+            f.array_indexof(col, literal(1.0)),
+            lambda: [py_indexof(r, 1.0) for r in data],
+        ],
+        [
+            f.list_position(col, literal(1.0)),
+            lambda: [py_indexof(r, 1.0) for r in data],
+        ],
+        [
+            f.list_indexof(col, literal(1.0)),
+            lambda: [py_indexof(r, 1.0) for r in data],
+        ],
+        [
+            f.array_positions(col, literal(1.0)),
+            lambda: [
+                [i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
+            ],
+        ],
+        [
+            f.list_positions(col, literal(1.0)),
+            lambda: [
+                [i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
+            ],
+        ],
+        [
+            f.array_ndims(col),
+            lambda: [np.array(r).ndim for r in data],
+        ],
+        [
+            f.list_ndims(col),
+            lambda: [np.array(r).ndim for r in data],
+        ],
+        [
+            f.array_prepend(literal(99.0), col),
+            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+        ],
+        [
+            f.array_push_front(literal(99.0), col),
+            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+        ],
+        [
+            f.list_prepend(literal(99.0), col),
+            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+        ],
+        [
+            f.list_push_front(literal(99.0), col),
+            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+        ],
+        [
+            f.array_pop_back(col),
+            lambda: [arr[:-1] for arr in data],
+        ],
+        [
+            f.array_pop_front(col),
+            lambda: [arr[1:] for arr in data],
+        ],
+        [
+            f.array_remove(col, literal(3.0)),
+            lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+        ],
+        [
+            f.list_remove(col, literal(3.0)),
+            lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+        ],
+        [
+            f.array_remove_n(col, literal(3.0), literal(2)),
+            lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+        ],
+        [
+            f.list_remove_n(col, literal(3.0), literal(2)),
+            lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+        ],
+        [
+            f.array_remove_all(col, literal(3.0)),
+            lambda: [py_arr_remove(arr, 3.0) for arr in data],
+        ],
+        [
+            f.list_remove_all(col, literal(3.0)),
+            lambda: [py_arr_remove(arr, 3.0) for arr in data],
+        ],
+        [
+            f.array_repeat(col, literal(2)),
+            lambda: [[arr] * 2 for arr in data],
+        ],
+        [
+            f.array_replace(col, literal(3.0), literal(4.0)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+        ],
+        [
+            f.list_replace(col, literal(3.0), literal(4.0)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+        ],
+        [
+            f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+        ],
+        [
+            f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
+        ],
+        [
+            f.array_replace_all(col, literal(3.0), literal(4.0)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+        ],
+        [
+            f.list_replace_all(col, literal(3.0), literal(4.0)),
+            lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+        ],
+        [
+            f.array_slice(col, literal(2), literal(4)),
+            lambda: [arr[1:4] for arr in data],
+        ],
+        [
+            f.list_slice(col, literal(-1), literal(2)),
+            lambda: [arr[-1:2] for arr in data],
+        ],
     ]
 
     for stmt, py_expr in test_items:
-        query_result = df.select(stmt).collect()[0].column(0).tolist()
+        query_result = df.select(stmt).collect()[0].column(0)
+        for a, b in zip(query_result, py_expr()):
+            np.testing.assert_array_almost_equal(
+                np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
+            )
+
+    obj_test_items = [
+        [
+            f.array_to_string(col, literal(",")),
+            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+        ],
+        [
+            f.array_join(col, literal(",")),
+            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+        ],
+        [
+            f.list_to_string(col, literal(",")),
+            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+        ],
+        [
+            f.list_join(col, literal(",")),
+            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+        ],
+    ]
+
+    for stmt, py_expr in obj_test_items:
+        query_result = np.array(df.select(stmt).collect()[0].column(0))
         for a, b in zip(query_result, py_expr()):
-            np.testing.assert_array_almost_equal(a, b)
+            assert a == b
 
 
 def test_string_functions(df):
diff --git a/src/functions.rs b/src/functions.rs
index 045e7e0..bb204c3 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -360,6 +360,9 @@ scalar_function!(decode, Decode);
 
 // Array Functions
 scalar_function!(array_append, ArrayAppend);
+scalar_function!(array_push_back, ArrayAppend);
+scalar_function!(list_append, ArrayAppend);
+scalar_function!(list_push_back, ArrayAppend);
 scalar_function!(array_concat, ArrayConcat);
 scalar_function!(array_cat, ArrayConcat);
 scalar_function!(array_dims, ArrayDims);
@@ -370,6 +373,42 @@ scalar_function!(list_element, ArrayElement);
 scalar_function!(list_extract, ArrayElement);
 scalar_function!(array_length, ArrayLength);
 scalar_function!(list_length, ArrayLength);
+scalar_function!(array_has, ArrayHas);
+scalar_function!(array_has_all, ArrayHasAll);
+scalar_function!(array_has_any, ArrayHasAny);
+scalar_function!(array_position, ArrayPosition);
+scalar_function!(array_indexof, ArrayPosition);
+scalar_function!(list_position, ArrayPosition);
+scalar_function!(list_indexof, ArrayPosition);
+scalar_function!(array_positions, ArrayPositions);
+scalar_function!(list_positions, ArrayPositions);
+scalar_function!(array_to_string, ArrayToString);
+scalar_function!(array_join, ArrayToString);
+scalar_function!(list_to_string, ArrayToString);
+scalar_function!(list_join, ArrayToString);
+scalar_function!(array_ndims, ArrayNdims);
+scalar_function!(list_ndims, ArrayNdims);
+scalar_function!(array_prepend, ArrayPrepend);
+scalar_function!(array_push_front, ArrayPrepend);
+scalar_function!(list_prepend, ArrayPrepend);
+scalar_function!(list_push_front, ArrayPrepend);
+scalar_function!(array_pop_back, ArrayPopBack);
+scalar_function!(array_pop_front, ArrayPopFront);
+scalar_function!(array_remove, ArrayRemove);
+scalar_function!(list_remove, ArrayRemove);
+scalar_function!(array_remove_n, ArrayRemoveN);
+scalar_function!(list_remove_n, ArrayRemoveN);
+scalar_function!(array_remove_all, ArrayRemoveAll);
+scalar_function!(list_remove_all, ArrayRemoveAll);
+scalar_function!(array_repeat, ArrayRepeat);
+scalar_function!(array_replace, ArrayReplace);
+scalar_function!(list_replace, ArrayReplace);
+scalar_function!(array_replace_n, ArrayReplaceN);
+scalar_function!(list_replace_n, ArrayReplaceN);
+scalar_function!(array_replace_all, ArrayReplaceAll);
+scalar_function!(list_replace_all, ArrayReplaceAll);
+scalar_function!(array_slice, ArraySlice);
+scalar_function!(list_slice, ArraySlice);
 
 aggregate_function!(approx_distinct, ApproxDistinct);
 aggregate_function!(approx_median, ApproxMedian);
@@ -563,6 +602,9 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
 
     // Array Functions
     m.add_wrapped(wrap_pyfunction!(array_append))?;
+    m.add_wrapped(wrap_pyfunction!(array_push_back))?;
+    m.add_wrapped(wrap_pyfunction!(list_append))?;
+    m.add_wrapped(wrap_pyfunction!(list_push_back))?;
     m.add_wrapped(wrap_pyfunction!(array_concat))?;
     m.add_wrapped(wrap_pyfunction!(array_cat))?;
     m.add_wrapped(wrap_pyfunction!(array_dims))?;
@@ -573,6 +615,42 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(list_extract))?;
     m.add_wrapped(wrap_pyfunction!(array_length))?;
     m.add_wrapped(wrap_pyfunction!(list_length))?;
+    m.add_wrapped(wrap_pyfunction!(array_has))?;
+    m.add_wrapped(wrap_pyfunction!(array_has_all))?;
+    m.add_wrapped(wrap_pyfunction!(array_has_any))?;
+    m.add_wrapped(wrap_pyfunction!(array_position))?;
+    m.add_wrapped(wrap_pyfunction!(array_indexof))?;
+    m.add_wrapped(wrap_pyfunction!(list_position))?;
+    m.add_wrapped(wrap_pyfunction!(list_indexof))?;
+    m.add_wrapped(wrap_pyfunction!(array_positions))?;
+    m.add_wrapped(wrap_pyfunction!(list_positions))?;
+    m.add_wrapped(wrap_pyfunction!(array_to_string))?;
+    m.add_wrapped(wrap_pyfunction!(array_join))?;
+    m.add_wrapped(wrap_pyfunction!(list_to_string))?;
+    m.add_wrapped(wrap_pyfunction!(list_join))?;
+    m.add_wrapped(wrap_pyfunction!(array_ndims))?;
+    m.add_wrapped(wrap_pyfunction!(list_ndims))?;
+    m.add_wrapped(wrap_pyfunction!(array_prepend))?;
+    m.add_wrapped(wrap_pyfunction!(array_push_front))?;
+    m.add_wrapped(wrap_pyfunction!(list_prepend))?;
+    m.add_wrapped(wrap_pyfunction!(list_push_front))?;
+    m.add_wrapped(wrap_pyfunction!(array_pop_back))?;
+    m.add_wrapped(wrap_pyfunction!(array_pop_front))?;
+    m.add_wrapped(wrap_pyfunction!(array_remove))?;
+    m.add_wrapped(wrap_pyfunction!(list_remove))?;
+    m.add_wrapped(wrap_pyfunction!(array_remove_n))?;
+    m.add_wrapped(wrap_pyfunction!(list_remove_n))?;
+    m.add_wrapped(wrap_pyfunction!(array_remove_all))?;
+    m.add_wrapped(wrap_pyfunction!(list_remove_all))?;
+    m.add_wrapped(wrap_pyfunction!(array_repeat))?;
+    m.add_wrapped(wrap_pyfunction!(array_replace))?;
+    m.add_wrapped(wrap_pyfunction!(list_replace))?;
+    m.add_wrapped(wrap_pyfunction!(array_replace_n))?;
+    m.add_wrapped(wrap_pyfunction!(list_replace_n))?;
+    m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
+    m.add_wrapped(wrap_pyfunction!(list_replace_all))?;
+    m.add_wrapped(wrap_pyfunction!(array_slice))?;
+    m.add_wrapped(wrap_pyfunction!(list_slice))?;
 
     Ok(())
 }

Reply via email to