This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new c609dfa3 feat: allow passing a slice to an expression with the []
indexing (#1215)
c609dfa3 is described below
commit c609dfa31abacc8c891e70ab1c0ae474c12789ee
Author: Tim Saucer <[email protected]>
AuthorDate: Fri Sep 5 13:14:49 2025 -0400
feat: allow passing a slice to an expression with the [] indexing (#1215)
* Allow passing a slice to an expression with the [] indexing
* Update documentation
* Add unit test covering expressions in slice
---
.../user-guide/common-operations/expressions.rst | 7 +++++
python/datafusion/expr.py | 31 ++++++++++++++++++++--
python/tests/test_functions.py | 29 +++++++++++++++++++-
3 files changed, 64 insertions(+), 3 deletions(-)
diff --git a/docs/source/user-guide/common-operations/expressions.rst
b/docs/source/user-guide/common-operations/expressions.rst
index 77607e88..7848b4ee 100644
--- a/docs/source/user-guide/common-operations/expressions.rst
+++ b/docs/source/user-guide/common-operations/expressions.rst
@@ -82,6 +82,13 @@ approaches.
Indexing an element of an array via ``[]`` starts at index 0 whereas
:py:func:`~datafusion.functions.array_element` starts at index 1.
+Starting in DataFusion 49.0.0 you can also create slices of array elements using
+slice syntax from Python.
+
+.. ipython:: python
+
+ df.select(col("a")[1:3].alias("second_two_elements"))
+
To check if an array is empty, you can use the function
:py:func:`datafusion.functions.array_empty` or `datafusion.functions.empty`.
This function returns a boolean indicating whether the array is empty.
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 035fd8ae..b5156040 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -352,17 +352,44 @@ class Expr:
"""Binary not (~)."""
return Expr(self.expr.__invert__())
- def __getitem__(self, key: str | int) -> Expr:
+ def __getitem__(self, key: str | int | slice) -> Expr:
"""Retrieve sub-object.
If ``key`` is a string, returns the subfield of the struct.
If ``key`` is an integer, retrieves the element in the array. Note
that the
- element index begins at ``0``, unlike `array_element` which begins at
``1``.
+ element index begins at ``0``, unlike
+ :py:func:`~datafusion.functions.array_element` which begins at ``1``.
+ If ``key`` is a slice, returns an array that contains a slice of the
+ original array. Similar to integer indexing, this follows Python
convention
+ where the index begins at ``0`` unlike
+ :py:func:`~datafusion.functions.array_slice` which begins at ``1``.
"""
if isinstance(key, int):
return Expr(
functions_internal.array_element(self.expr, Expr.literal(key +
1).expr)
)
+ if isinstance(key, slice):
+ if isinstance(key.start, int):
+ start = Expr.literal(key.start + 1).expr
+ elif isinstance(key.start, Expr):
+ start = (key.start + Expr.literal(1)).expr
+ else:
+ # Default start at the first element, index 1
+ start = Expr.literal(1).expr
+
+ if isinstance(key.stop, int):
+ stop = Expr.literal(key.stop).expr
+ else:
+ stop = key.stop.expr
+
+ if isinstance(key.step, int):
+ step = Expr.literal(key.step).expr
+ elif isinstance(key.step, Expr):
+ step = key.step.expr
+ else:
+ step = key.step
+
+ return Expr(functions_internal.array_slice(self.expr, start, stop,
step))
return Expr(self.expr.__getitem__(key))
def __eq__(self, rhs: object) -> Expr:
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index 52591531..ee19d021 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -494,6 +494,30 @@ def py_flatten(arr):
lambda col: f.list_slice(col, literal(-1), literal(2)),
lambda data: [arr[-1:2] for arr in data],
),
+ (
+ lambda col: col[:3],
+ lambda data: [arr[:3] for arr in data],
+ ),
+ (
+ lambda col: col[1:3],
+ lambda data: [arr[1:3] for arr in data],
+ ),
+ (
+ lambda col: col[1:4:2],
+ lambda data: [arr[1:4:2] for arr in data],
+ ),
+ (
+ lambda col: col[literal(1) : literal(4)],
+ lambda data: [arr[1:4] for arr in data],
+ ),
+ (
+ lambda col: col[column("indices") : column("indices") +
literal(2)],
+ lambda data: [[2.0, 3.0], [], [6.0]],
+ ),
+ (
+ lambda col: col[literal(1) : literal(4) : literal(2)],
+ lambda data: [arr[1:4:2] for arr in data],
+ ),
(
lambda col: f.array_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
@@ -534,8 +558,11 @@ def py_flatten(arr):
)
def test_array_functions(stmt, py_expr):
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+ indices = [1, 3, 0]
ctx = SessionContext()
- batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)],
names=["arr"])
+ batch = pa.RecordBatch.from_arrays(
+ [np.array(data, dtype=object), indices], names=["arr", "indices"]
+ )
df = ctx.create_dataframe([[batch]])
col = column("arr")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]