This is an automated email from the ASF dual-hosted git repository.

alenka pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7347eb27c2 GH-40644: [Python] Allow passing a mapping of column names to `rename_columns` (#40645)
7347eb27c2 is described below

commit 7347eb27c2a2f05ba1782fcd102f53ccc7a36ee4
Author: Judah Rand <17158624+judahr...@users.noreply.github.com>
AuthorDate: Mon Apr 15 11:43:01 2024 +0100

    GH-40644: [Python] Allow passing a mapping of column names to `rename_columns` (#40645)
    
    
    
    ### Rationale for this change
    
    See #40644
    
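    ### What changes are included in this PR?
    
    `Table.rename_columns` and `RecordBatch.rename_columns` now accept either a
    list of new names or a dict mapping existing column names to new names. With
    a dict, every column whose name matches a key (including duplicated names) is
    renamed, columns not in the mapping keep their names, and a key that matches
    no column raises `KeyError`. Any other argument type raises `TypeError`.
    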
    ### Are these changes tested?
    
    Yes.
    
    Tests have been added.
    
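    ### Are there any user-facing changes?
    
    Yes. `rename_columns` additionally accepts a dict mapping old to new column
    names. Note that `names` must now be a `list` or a `dict`: other iterables
    (e.g. a tuple of names), which the previous `for name in names` loop
    accepted, now raise `TypeError`.
    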
    * GitHub Issue: #40644
    
    Authored-by: Judah Rand <17158624+judahr...@users.noreply.github.com>
    Signed-off-by: AlenkaF <frim.ale...@gmail.com>
---
 python/pyarrow/table.pxi           | 83 ++++++++++++++++++++++++++++++++++----
 python/pyarrow/tests/test_table.py | 37 +++++++++++++++++
 2 files changed, 112 insertions(+), 8 deletions(-)

diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index d31ea0a5fa..0ba8b4debd 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2816,8 +2816,17 @@ cdef class RecordBatch(_Tabular):
 
         Parameters
         ----------
-        names : list of str
-            List of new column names.
+        names : list[str] or dict[str, str]
+            List of new column names or mapping of old column names to new column names.
+
+            If a mapping of old to new column names is passed, then all columns which are
+            found to match a provided old column name will be renamed to the new column name.
+            If an old column name in the mapping does not exist, a KeyError will be raised.
+
+        Raises
+        ------
+        KeyError
+            If any of the column names passed in the names mapping do not exist.
 
         Returns
         -------
@@ -2838,13 +2847,38 @@ cdef class RecordBatch(_Tabular):
         ----
         n: [2,4,5,100]
         name: ["Flamingo","Horse","Brittle stars","Centipede"]
+        >>> new_names = {"n_legs": "n", "animals": "name"}
+        >>> batch.rename_columns(new_names)
+        pyarrow.RecordBatch
+        n: int64
+        name: string
+        ----
+        n: [2,4,5,100]
+        name: ["Flamingo","Horse","Brittle stars","Centipede"]
         """
         cdef:
             shared_ptr[CRecordBatch] c_batch
             vector[c_string] c_names
 
-        for name in names:
-            c_names.push_back(tobytes(name))
+        if isinstance(names, list):
+            for name in names:
+                c_names.push_back(tobytes(name))
+        elif isinstance(names, dict):
+            idx_to_new_name = {}
+            for name, new_name in names.items():
+                indices = self.schema.get_all_field_indices(name)
+
+                if not indices:
+                    raise KeyError("Column {!r} not found".format(name))
+
+                for index in indices:
+                    idx_to_new_name[index] = new_name
+
+            for i in range(self.num_columns):
+                new_name = idx_to_new_name.get(i, self.column_names[i])
+                c_names.push_back(tobytes(new_name))
+        else:
+            raise TypeError(f"names must be a list or dict not {type(names)!r}")
 
         with nogil:
             c_batch = GetResultValue(self.batch.RenameColumns(move(c_names)))
@@ -5215,8 +5249,17 @@ cdef class Table(_Tabular):
 
         Parameters
         ----------
-        names : list of str
-            List of new column names.
+        names : list[str] or dict[str, str]
+            List of new column names or mapping of old column names to new column names.
+
+            If a mapping of old to new column names is passed, then all columns which are
+            found to match a provided old column name will be renamed to the new column name.
+            If an old column name in the mapping does not exist, a KeyError will be raised.
+
+        Raises
+        ------
+        KeyError
+            If any of the column names passed in the names mapping do not exist.
 
         Returns
         -------
@@ -5237,13 +5280,37 @@ cdef class Table(_Tabular):
         ----
         n: [[2,4,5,100]]
         name: [["Flamingo","Horse","Brittle stars","Centipede"]]
+        >>> new_names = {"n_legs": "n", "animals": "name"}
+        >>> table.rename_columns(new_names)
+        pyarrow.Table
+        n: int64
+        name: string
+        ----
+        n: [[2,4,5,100]]
+        name: [["Flamingo","Horse","Brittle stars","Centipede"]]
         """
         cdef:
             shared_ptr[CTable] c_table
             vector[c_string] c_names
 
-        for name in names:
-            c_names.push_back(tobytes(name))
+        if isinstance(names, list):
+            for name in names:
+                c_names.push_back(tobytes(name))
+        elif isinstance(names, dict):
+            idx_to_new_name = {}
+            for name, new_name in names.items():
+                indices = self.schema.get_all_field_indices(name)
+
+                if not indices:
+                    raise KeyError("Column {!r} not found".format(name))
+
+                for index in indices:
+                    idx_to_new_name[index] = new_name
+
+            for i in range(self.num_columns):
+                c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name)))
+        else:
+            raise TypeError(f"names must be a list or dict not {type(names)!r}")
 
         with nogil:
             c_table = GetResultValue(self.table.RenameColumns(move(c_names)))
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 539da0e685..7a140d4132 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -1737,6 +1737,43 @@ def test_table_rename_columns(cls):
     expected = cls.from_arrays(data, names=['eh', 'bee', 'sea'])
     assert t2.equals(expected)
 
+    message = "names must be a list or dict not <class 'str'>"
+    with pytest.raises(TypeError, match=message):
+        table.rename_columns('not a list')
+
+
+@pytest.mark.parametrize(
+    ('cls'),
+    [
+        (pa.Table),
+        (pa.RecordBatch)
+    ]
+)
+def test_table_rename_columns_mapping(cls):  # dict form of rename_columns
+    data = [
+        pa.array(range(5)),
+        pa.array([-10, -5, 0, 5, 10]),
+        pa.array(range(5, 10))
+    ]
+    table = cls.from_arrays(data, names=['a', 'b', 'c'])
+    assert table.column_names == ['a', 'b', 'c']
+
+    expected = cls.from_arrays(data, names=['eh', 'b', 'sea'])
+    t1 = table.rename_columns({'a': 'eh', 'c': 'sea'})  # 'b' is left unmapped
+    t1.validate()
+    assert t1 == expected
+
+    # Every column whose name matches a mapping key is renamed, duplicates included
+    table = cls.from_arrays(data, names=['a', 'a', 'c'])
+    expected = cls.from_arrays(data, names=['eh', 'eh', 'sea'])
+    t2 = table.rename_columns({'a': 'eh', 'c': 'sea'})
+    t2.validate()
+    assert t2 == expected
+
+    # A mapping key that matches no column ('d') raises KeyError
+    with pytest.raises(KeyError, match=r"Column 'd' not found"):
+        table.rename_columns({'a': 'eh', 'd': 'sea'})
+
 
 def test_table_flatten():
     ty1 = pa.struct([pa.field('x', pa.int16()),

Reply via email to