This is an automated email from the ASF dual-hosted git repository.
alenka pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 7347eb27c2 GH-40644: [Python] Allow passing a mapping of column names to `rename_columns` (#40645)
7347eb27c2 is described below
commit 7347eb27c2a2f05ba1782fcd102f53ccc7a36ee4
Author: Judah Rand <[email protected]>
AuthorDate: Mon Apr 15 11:43:01 2024 +0100
GH-40644: [Python] Allow passing a mapping of column names to `rename_columns` (#40645)
### Rationale for this change
See #40644
### What changes are included in this PR?
### Are these changes tested?
Yes.
Tests have been added.
### Are there any user-facing changes?
* GitHub Issue: #40644
Authored-by: Judah Rand <[email protected]>
Signed-off-by: AlenkaF <[email protected]>
---
python/pyarrow/table.pxi | 83 ++++++++++++++++++++++++++++++++++----
python/pyarrow/tests/test_table.py | 37 +++++++++++++++++
2 files changed, 112 insertions(+), 8 deletions(-)
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index d31ea0a5fa..0ba8b4debd 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2816,8 +2816,17 @@ cdef class RecordBatch(_Tabular):
Parameters
----------
- names : list of str
- List of new column names.
+ names : list[str] or dict[str, str]
+ List of new column names or mapping of old column names to new column names.
+
+ If a mapping of old to new column names is passed, then all columns which are
+ found to match a provided old column name will be renamed to the new column name.
+ If any column names are not found in the mapping, a KeyError will be raised.
+
+ Raises
+ ------
+ KeyError
+ If any of the column names passed in the names mapping do not exist.
Returns
-------
@@ -2838,13 +2847,38 @@ cdef class RecordBatch(_Tabular):
----
n: [2,4,5,100]
name: ["Flamingo","Horse","Brittle stars","Centipede"]
+ >>> new_names = {"n_legs": "n", "animals": "name"}
+ >>> batch.rename_columns(new_names)
+ pyarrow.RecordBatch
+ n: int64
+ name: string
+ ----
+ n: [2,4,5,100]
+ name: ["Flamingo","Horse","Brittle stars","Centipede"]
"""
cdef:
shared_ptr[CRecordBatch] c_batch
vector[c_string] c_names
- for name in names:
- c_names.push_back(tobytes(name))
+ if isinstance(names, list):
+ for name in names:
+ c_names.push_back(tobytes(name))
+ elif isinstance(names, dict):
+ idx_to_new_name = {}
+ for name, new_name in names.items():
+ indices = self.schema.get_all_field_indices(name)
+
+ if not indices:
+ raise KeyError("Column {!r} not found".format(name))
+
+ for index in indices:
+ idx_to_new_name[index] = new_name
+
+ for i in range(self.num_columns):
+ new_name = idx_to_new_name.get(i, self.column_names[i])
+ c_names.push_back(tobytes(new_name))
+ else:
+ raise TypeError(f"names must be a list or dict not {type(names)!r}")
with nogil:
c_batch = GetResultValue(self.batch.RenameColumns(move(c_names)))
@@ -5215,8 +5249,17 @@ cdef class Table(_Tabular):
Parameters
----------
- names : list of str
- List of new column names.
+ names : list[str] or dict[str, str]
+ List of new column names or mapping of old column names to new column names.
+
+ If a mapping of old to new column names is passed, then all columns which are
+ found to match a provided old column name will be renamed to the new column name.
+ If any column names are not found in the mapping, a KeyError will be raised.
+
+ Raises
+ ------
+ KeyError
+ If any of the column names passed in the names mapping do not exist.
Returns
-------
@@ -5237,13 +5280,37 @@ cdef class Table(_Tabular):
----
n: [[2,4,5,100]]
name: [["Flamingo","Horse","Brittle stars","Centipede"]]
+ >>> new_names = {"n_legs": "n", "animals": "name"}
+ >>> table.rename_columns(new_names)
+ pyarrow.Table
+ n: int64
+ name: string
+ ----
+ n: [[2,4,5,100]]
+ name: [["Flamingo","Horse","Brittle stars","Centipede"]]
"""
cdef:
shared_ptr[CTable] c_table
vector[c_string] c_names
- for name in names:
- c_names.push_back(tobytes(name))
+ if isinstance(names, list):
+ for name in names:
+ c_names.push_back(tobytes(name))
+ elif isinstance(names, dict):
+ idx_to_new_name = {}
+ for name, new_name in names.items():
+ indices = self.schema.get_all_field_indices(name)
+
+ if not indices:
+ raise KeyError("Column {!r} not found".format(name))
+
+ for index in indices:
+ idx_to_new_name[index] = new_name
+
+ for i in range(self.num_columns):
+ c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name)))
+ else:
+ raise TypeError(f"names must be a list or dict not {type(names)!r}")
with nogil:
c_table = GetResultValue(self.table.RenameColumns(move(c_names)))
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 539da0e685..7a140d4132 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -1737,6 +1737,43 @@ def test_table_rename_columns(cls):
expected = cls.from_arrays(data, names=['eh', 'bee', 'sea'])
assert t2.equals(expected)
+ message = "names must be a list or dict not <class 'str'>"
+ with pytest.raises(TypeError, match=message):
+ table.rename_columns('not a list')
+
+
+@pytest.mark.parametrize(
+ ('cls'),
+ [
+ (pa.Table),
+ (pa.RecordBatch)
+ ]
+)
+def test_table_rename_columns_mapping(cls):
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = cls.from_arrays(data, names=['a', 'b', 'c'])
+ assert table.column_names == ['a', 'b', 'c']
+
+ expected = cls.from_arrays(data, names=['eh', 'b', 'sea'])
+ t1 = table.rename_columns({'a': 'eh', 'c': 'sea'})
+ t1.validate()
+ assert t1 == expected
+
+ # Test renaming duplicate column names
+ table = cls.from_arrays(data, names=['a', 'a', 'c'])
+ expected = cls.from_arrays(data, names=['eh', 'eh', 'sea'])
+ t2 = table.rename_columns({'a': 'eh', 'c': 'sea'})
+ t2.validate()
+ assert t2 == expected
+
+ # Test column not found
+ with pytest.raises(KeyError, match=r"Column 'd' not found"):
+ table.rename_columns({'a': 'eh', 'd': 'sea'})
+
def test_table_flatten():
ty1 = pa.struct([pa.field('x', pa.int16()),