nmehran commented on issue #30950:
URL: https://github.com/apache/arrow/issues/30950#issuecomment-2244460072

   
[pyarrow_utils.py](https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e)
   
   ```python
   import uuid
   from typing import Optional, List, Literal

   import numpy as np
   import pyarrow as pa
   import pyarrow.compute as pc
   
   
   def drop_duplicates(
       table: pa.Table,
       on: Optional[List[str]] = None,
       keep: Literal['first', 'last'] = 'first'
   ) -> pa.Table:
       """
       Remove duplicate rows from a PyArrow table based on specified columns.

       Each row is tagged with its position, the table is grouped by the
       ``on`` columns, and the first or last position in every group is kept.
       The surviving rows are returned in their original relative order.

       Args:
           table (pa.Table): The input PyArrow table.
           on (Optional[List[str]]): List of column names to consider for
               identifying duplicates. If None (or empty), all columns are
               used.
           keep (Literal['first', 'last']): Whether to keep the first or last
               occurrence of duplicates.

       Returns:
           pa.Table: A new PyArrow table with duplicates removed. It has the
           same columns as ``table``.

       Raises:
           ValueError: If 'keep' is not 'first' or 'last'.

       Example:
           >>> import pyarrow as pa
           >>> data = [
           ...     pa.array([1, 2, 2, 3]),
           ...     pa.array(['a', 'b', 'b', 'c']),
           ...     pa.array([10, 20, 30, 40])
           ... ]
           >>> table = pa.Table.from_arrays(data, names=['id', 'name', 'value'])
           >>> deduped = drop_duplicates(table, on=['id', 'name'], keep='first')
           >>> deduped.to_pydict()
           {'id': [1, 2, 3], 'name': ['a', 'b', 'c'], 'value': [10, 20, 40]}
       """
       if keep not in ('first', 'last'):
           raise ValueError("Parameter 'keep' must be either 'first' or 'last'")

       if not on:
           on = table.column_names

       num_rows = table.num_rows

       # Tag every row with its position so the surviving occurrence of each
       # group can be identified after aggregation.
       row_numbers = pa.array(
           np.arange(num_rows, dtype=np.int64), type=pa.int64()
       )
       # A random suffix keeps the helper column from clashing with any
       # existing column name in the table.
       unique_row_index = f"index_{uuid.uuid4().hex}"
       table_with_index = table.append_column(unique_row_index, row_numbers)

       # Group by the key columns and keep one row number per group.
       # 'first'/'last' depend on row order, hence use_threads=False.
       unique_indices = (
           table_with_index
           .group_by(on, use_threads=False)
           .aggregate([(unique_row_index, keep)])
       )

       # The aggregate output column is named '<column>_<aggregation>'.
       unique_row_numbers = unique_indices.column(f'{unique_row_index}_{keep}')
       mask = pc.is_in(row_numbers, value_set=unique_row_numbers)

       # The mask is aligned with the original rows, so filter the original
       # table directly; the helper column never needs to be dropped.
       return table.filter(mask)
       


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to