This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d834d65c3 GH-36709: [Python] Allow to specify use_threads=False in 
Table.group_by to have stable ordering (#36768)
7d834d65c3 is described below

commit 7d834d65c37c17d1c19bfb497eadb983893c9ea0
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Thu Oct 5 09:21:56 2023 +0200

    GH-36709: [Python] Allow to specify use_threads=False in Table.group_by to 
have stable ordering (#36768)
    
    ### Rationale for this change
    
    Add a `use_threads` keyword to the `group_by` method on Table, and pass 
it through to the Declaration.to_table call. This allows specifying 
`use_threads=False` to get a stable ordering of the output, which is also 
required for certain aggregations (e.g. `"first"` will fail with the 
default of `use_threads=True`).
    
    ### Are these changes tested?
    
    Yes, added a test (similar to the one we have for this for `filter`) that 
would fail (>50% of the time) if the output was no longer ordered.
    
    * Closes: #36709
    
    Authored-by: Joris Van den Bossche <[email protected]>
    Signed-off-by: Joris Van den Bossche <[email protected]>
---
 python/pyarrow/acero.py                |  4 ++--
 python/pyarrow/table.pxi               | 20 +++++++++++++++-----
 python/pyarrow/tests/test_exec_plan.py | 14 ++++++++++++++
 python/pyarrow/tests/test_table.py     | 15 +++++++++++++++
 4 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/python/pyarrow/acero.py b/python/pyarrow/acero.py
index 63da0a3786..0609e45753 100644
--- a/python/pyarrow/acero.py
+++ b/python/pyarrow/acero.py
@@ -299,10 +299,10 @@ def _sort_source(table_or_dataset, sort_keys, 
output_type=Table, **kwargs):
         raise TypeError("Unsupported output type")
 
 
-def _group_by(table, aggregates, keys):
+def _group_by(table, aggregates, keys, use_threads=True):
 
     decl = Declaration.from_sequence([
         Declaration("table_source", TableSourceNodeOptions(table)),
         Declaration("aggregate", AggregateNodeOptions(aggregates, keys=keys))
     ])
-    return decl.to_table(use_threads=True)
+    return decl.to_table(use_threads=use_threads)
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 2eae38485d..36601130b3 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -4599,8 +4599,9 @@ cdef class Table(_Tabular):
         """
         return self.drop_columns(columns)
 
-    def group_by(self, keys):
-        """Declare a grouping over the columns of the table.
+    def group_by(self, keys, use_threads=True):
+        """
+        Declare a grouping over the columns of the table.
 
         Resulting grouping can then be used to perform aggregations
         with a subsequent ``aggregate()`` method.
@@ -4609,6 +4610,9 @@ cdef class Table(_Tabular):
         ----------
         keys : str or list[str]
             Name of the columns that should be used as the grouping key.
+        use_threads : bool, default True
+            Whether to use multithreading or not. When set to True (the
+            default), no stable ordering of the output is guaranteed.
 
         Returns
         -------
@@ -4635,7 +4639,7 @@ cdef class Table(_Tabular):
         year: [[2020,2022,2021,2019]]
         n_legs_sum: [[2,6,104,5]]
         """
-        return TableGroupBy(self, keys)
+        return TableGroupBy(self, keys, use_threads=use_threads)
 
     def join(self, right_table, keys, right_keys=None, join_type="left outer",
              left_suffix=None, right_suffix=None, coalesce_keys=True,
@@ -5183,6 +5187,9 @@ class TableGroupBy:
         Input table to execute the aggregation on.
     keys : str or list[str]
         Name of the grouped columns.
+    use_threads : bool, default True
+        Whether to use multithreading or not. When set to True (the default),
+        no stable ordering of the output is guaranteed.
 
     Examples
     --------
@@ -5208,12 +5215,13 @@ class TableGroupBy:
     values_sum: [[3,7,5]]
     """
 
-    def __init__(self, table, keys):
+    def __init__(self, table, keys, use_threads=True):
         if isinstance(keys, str):
             keys = [keys]
 
         self._table = table
         self.keys = keys
+        self._use_threads = use_threads
 
     def aggregate(self, aggregations):
         """
@@ -5328,4 +5336,6 @@ list[tuple(str, str, FunctionOptions)]
                 aggr_name = "_".join(target) + "_" + func_nohash
             group_by_aggrs.append((target, func, opt, aggr_name))
 
-        return _pac()._group_by(self._table, group_by_aggrs, self.keys)
+        return _pac()._group_by(
+            self._table, group_by_aggrs, self.keys, 
use_threads=self._use_threads
+        )
diff --git a/python/pyarrow/tests/test_exec_plan.py 
b/python/pyarrow/tests/test_exec_plan.py
index 58c618179b..d85a2c2152 100644
--- a/python/pyarrow/tests/test_exec_plan.py
+++ b/python/pyarrow/tests/test_exec_plan.py
@@ -321,3 +321,17 @@ def test_join_extension_array_column():
     result = _perform_join(
         "left outer", t1, ["colB"], t3, ["colC"])
     assert result["colB"] == pa.chunked_array(ext_array)
+
+
+def test_group_by_ordering():
+    # GH-36709 - preserve ordering in groupby by setting use_threads=False
+    table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
+    table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
+    table = pa.concat_tables([table1, table2])
+
+    for _ in range(50):
+        # 50 seems to consistently cause errors when order is not preserved.
+        # If the order problem is reintroduced this test will become flaky
+        # which is still a signal that the order is not preserved.
+        result = table.group_by("b", use_threads=False).aggregate([])
+        assert result["b"] == pa.chunked_array([["a"], ["b"]])
diff --git a/python/pyarrow/tests/test_table.py 
b/python/pyarrow/tests/test_table.py
index f93c6bbc2c..b9e0d69219 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2175,6 +2175,21 @@ def test_table_group_by():
     }
 
 
[email protected]
+def test_table_group_by_first():
+    # "first" is an ordered aggregation -> requires to specify 
use_threads=False
+    table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a', 'b'] * 2})
+    table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b', 'a'] * 2})
+    table = pa.concat_tables([table1, table2])
+
+    with pytest.raises(NotImplementedError):
+        table.group_by("b").aggregate([("a", "first")])
+
+    result = table.group_by("b", use_threads=False).aggregate([("a", "first")])
+    expected = pa.table({"b": ["a", "b"], "a_first": [1, 2]})
+    assert result.equals(expected)
+
+
 def test_table_to_recordbatchreader():
     table = pa.Table.from_pydict({'x': [1, 2, 3]})
     reader = table.to_reader()

Reply via email to