nmehran commented on issue #30950:
URL: https://github.com/apache/arrow/issues/30950#issuecomment-2246193955

   
[pyarrow_drop_duplicates.py](https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e)
   
   ```python
   import time
   from typing import List, Literal, Tuple, Callable, Optional
   from uuid import uuid4
   
   import numpy as np
   import pyarrow as pa
   import pyarrow.compute as pc
   
   
   def drop_duplicates(
           table: pa.Table,
           on: Optional[List[str]] = None,
           keep: Literal['first', 'last'] = 'first'
   ) -> pa.Table:
       """
       Remove duplicate rows from a PyArrow table based on specified columns.
       This function efficiently removes duplicate rows from a PyArrow table,
       keeping either the first or last occurrence of each unique combination
       of values in the specified columns.
       Args:
           table (pa.Table): The input PyArrow table.
           on (Optional[List[str]]): List of column names to consider for identifying duplicates.
               If None, all columns are used.
           keep (Literal['first', 'last']): Whether to keep the first or last occurrence of duplicates.
       Returns:
           pa.Table: A new PyArrow table with duplicates removed.
       Raises:
           ValueError: If 'keep' is not 'first' or 'last'.
           TypeError: If 'table' is not a PyArrow Table.
       Example:
           >>> import pyarrow as pa
           >>> data = [
           ...     pa.array([1, 2, 2, 3]),
           ...     pa.array(['a', 'b', 'b', 'c']),
           ...     pa.array([10, 20, 30, 40])
           ... ]
           >>> table = pa.Table.from_arrays(data, names=['id', 'name', 'value'])
           >>> deduped = drop_duplicates(table, on=['id', 'name'], keep='first')
           >>> print(deduped)
           pyarrow.Table
           id: int64
           name: string
           value: int64
           ----
           id: [1, 2, 3]
           name: ["a", "b", "c"]
           value: [10, 20, 40]
       """
       if not isinstance(table, pa.Table):
           raise TypeError("Parameter 'table' must be a PyArrow Table")
   
       if keep not in ['first', 'last']:
           raise ValueError("Parameter 'keep' must be either 'first' or 'last'")
   
       if not on:
           on = table.column_names
   
       # Generate a unique column name for row index
       index_column = f"index_{uuid4().hex}"
       index_aggregate_column = f'{index_column}_{keep}'
   
       # Create row numbers
       num_rows = table.num_rows
       row_numbers = pa.array(np.arange(num_rows, dtype=np.int64))
   
       # Append row numbers, group by specified columns, and aggregate
       unique_indices = (
           table.append_column(index_column, row_numbers)
           .group_by(on, use_threads=False)
           .aggregate([(index_column, keep)])
       ).column(index_aggregate_column)
   
       return pc.take(table, unique_indices, boundscheck=False)
       
   
   def drop_duplicates_filter(table, on=None, keep='first'):
       if not on:
           on = table.column_names
       row_numbers = pa.array(np.arange(table.num_rows, dtype=np.int64))
       index_column = f"index_{uuid4().hex}"
       index_aggregate_column = f'{index_column}_{keep}'
       table_with_index = table.append_column(index_column, row_numbers)
       unique_indices = table_with_index.group_by(
           on, use_threads=False
       ).aggregate([(index_column, keep)])
       unique_row_numbers = unique_indices.column(index_aggregate_column)
       mask = pc.is_in(row_numbers, value_set=unique_row_numbers)
       return table.filter(mask)
   
   
   def drop_duplicates_join(table, on=None, keep='first'):
       if not on:
           on = table.column_names
       index_column = f"index_{uuid4().hex}"
       index_aggregate_column = f'{index_column}_{keep}'
       row_numbers = pa.array(np.arange(table.num_rows, dtype=np.int64))
       table_with_index = table.append_column(index_column, row_numbers)
       unique_indices = table_with_index.group_by(
           on, use_threads=False
       ).aggregate([(index_column, keep)])
       return table_with_index.join(
           unique_indices,
           keys=index_column,
           right_keys=index_aggregate_column,
           join_type='left semi',
           use_threads=True
       ).drop(index_column)
   ```
   
   ### Results
   
   ```
   Benchmarking with 10,000,000 rows and 10,000 groups
   
   Benchmarking with keep='first':
   Filter method time: 2.8038 seconds
   Join method time: 1.9018 seconds
   Take method time: 1.6372 seconds
   
   Benchmarking with keep='last':
   Filter method time: 2.8951 seconds
   Join method time: 1.9269 seconds
   Take method time: 1.6332 seconds
   ```
   
   ### Findings
   
   1. Both the Join and Take methods consistently outperformed the current Filter method.
   2. The Take method showed the best performance, reducing execution time by ~43%.
   3. The Join method also showed a significant improvement, reducing time by ~33%.
   4. Performance gains were consistent for both `keep='first'` and `keep='last'`.
   5. The Filter and Take methods produce identical results, but the Join method uses threading and does not preserve row order.
   
   See the [gist](https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e) for the full implementation. Leave a comment on the gist if you can make this more efficient.

