nmehran commented on issue #30950: URL: https://github.com/apache/arrow/issues/30950#issuecomment-2244460072
# Source gist: pyarrow_utils.py
# https://gist.github.com/nmehran/57f264bd951b2f77af08f760eafea40e

import uuid
from typing import Optional, List, Literal

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc


def drop_duplicates(
    table: pa.Table,
    on: Optional[List[str]] = None,
    keep: Literal['first', 'last'] = 'first'
) -> pa.Table:
    """
    Remove duplicate rows from a PyArrow table based on specified columns.

    Rows are considered duplicates when they share the same combination of
    values in the ``on`` columns. For every such combination exactly one row
    survives: the first or the last occurrence, depending on ``keep``. The
    surviving rows keep their original relative order, and all columns of
    the input table are returned.

    Args:
        table (pa.Table): The input PyArrow table.
        on (Optional[List[str]]): Column names used to identify duplicates.
            If None or empty, all columns are used.
        keep (Literal['first', 'last']): Whether to keep the first or the
            last occurrence of each duplicate group.

    Returns:
        pa.Table: A new PyArrow table with duplicates removed.

    Raises:
        ValueError: If 'keep' is not 'first' or 'last'.

    Example:
        >>> import pyarrow as pa
        >>> data = [
        ...     pa.array([1, 2, 2, 3]),
        ...     pa.array(['a', 'b', 'b', 'c']),
        ...     pa.array([10, 20, 30, 40])
        ... ]
        >>> table = pa.Table.from_arrays(data, names=['id', 'name', 'value'])
        >>> deduped = drop_duplicates(table, on=['id', 'name'], keep='first')
        >>> deduped.to_pydict()
        {'id': [1, 2, 3], 'name': ['a', 'b', 'c'], 'value': [10, 20, 40]}
    """
    if keep not in ('first', 'last'):
        raise ValueError("Parameter 'keep' must be either 'first' or 'last'")

    if not on:
        on = table.column_names

    # Tag every row with its position so that the row chosen by the
    # aggregation can be located again in the original table.
    row_numbers = pa.array(
        np.arange(table.num_rows, dtype=np.int64), type=pa.int64()
    )

    # A random column name avoids clashing with any existing column.
    unique_row_index = f"index_{uuid.uuid4().hex}"
    table_with_index = table.append_column(unique_row_index, row_numbers)

    # For each key combination, pick the row number of the first/last
    # occurrence. 'first'/'last' depend on row order, so the group-by is
    # run single-threaded.
    unique_indices = (
        table_with_index
        .group_by(on, use_threads=False)
        .aggregate([(unique_row_index, keep)])
    )

    # The aggregated column is named "<column>_<aggregation>".
    unique_row_numbers = unique_indices.column(f'{unique_row_index}_{keep}')

    # Filtering with a boolean mask (rather than taking the aggregated
    # indices directly) preserves the original row order.
    mask = pc.is_in(row_numbers, value_set=unique_row_numbers)

    # Apply the filter and remove the temporary index column.
    return table_with_index.filter(mask).drop([unique_row_index])


# --
# This is an automated message from the Apache Git Service.
# To respond to the message, please log on to GitHub and use the
# URL above to go to the specific comment.
#
# To unsubscribe, e-mail: [email protected]
#
# For queries about this service, please contact Infrastructure at:
# [email protected]
