This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 62b527e7 Sanitize special character column names when writing  (#590)
62b527e7 is described below

commit 62b527e74cff328fde5faa848519c56748a244c1
Author: Kevin Liu <[email protected]>
AuthorDate: Wed Apr 17 04:47:02 2024 -0700

    Sanitize special character column names when writing  (#590)
    
    * write with sanitized column names
    
    * push down to when parquet writes
    
    * add test for writing special character column name
    
    * parameterize format_version
    
    * use to_requested_schema
    
    * refactor to_requested_schema
    
    * more refactor
    
    * test nested schema
    
    * special character inside nested field
    
    * comment on why arrow is enabled
    
    * use existing variable
    
    * move spark config to conftest
    
    * pyspark arrow turns pandas df from tuple to dict
    
    * Revert refactor to_requested_schema
    
    * reorder args
    
    * refactor
    
    * pushdown schema
    
    * only transform when necessary
---
 pyiceberg/io/pyarrow.py                      | 24 +++++++++++++++---------
 tests/conftest.py                            |  1 +
 tests/integration/test_inspect_table.py      |  3 ---
 tests/integration/test_writes/test_writes.py | 20 +++++++++++++++++++-
 4 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 67ebaa81..f8deb2f9 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -122,6 +122,7 @@ from pyiceberg.schema import (
     pre_order_visit,
     promote,
     prune_columns,
+    sanitize_column_names,
     visit,
     visit_with_partner,
 )
@@ -1016,7 +1017,6 @@ def _task_to_table(
 
         if len(arrow_table) < 1:
             return None
-
         return to_requested_schema(projected_schema, file_project_schema, 
arrow_table)
 
 
@@ -1769,10 +1769,7 @@ def data_file_statistics_from_parquet_metadata(
 
 
 def write_file(io: FileIO, table_metadata: TableMetadata, tasks: 
Iterator[WriteTask]) -> Iterator[DataFile]:
-    schema = table_metadata.schema()
-    arrow_file_schema = schema.as_arrow()
     parquet_writer_kwargs = 
_get_parquet_writer_kwargs(table_metadata.properties)
-
     row_group_size = PropertyUtil.property_as_int(
         properties=table_metadata.properties,
         property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
@@ -1780,16 +1777,25 @@ def write_file(io: FileIO, table_metadata: 
TableMetadata, tasks: Iterator[WriteT
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
+        table_schema = task.schema
+        arrow_table = pa.Table.from_batches(task.record_batches)
+        # if schema needs to be transformed, use the transformed schema and 
adjust the arrow table accordingly
+        # otherwise use the original schema
+        if (sanitized_schema := sanitize_column_names(table_schema)) != 
table_schema:
+            file_schema = sanitized_schema
+            arrow_table = to_requested_schema(requested_schema=file_schema, 
file_schema=table_schema, table=arrow_table)
+        else:
+            file_schema = table_schema
+
         file_path = 
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=arrow_file_schema, 
**parquet_writer_kwargs) as writer:
-                writer.write(pa.Table.from_batches(task.record_batches), 
row_group_size=row_group_size)
-
+            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), 
**parquet_writer_kwargs) as writer:
+                writer.write(arrow_table, row_group_size=row_group_size)
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=writer.writer.metadata,
-            stats_columns=compute_statistics_plan(schema, 
table_metadata.properties),
-            parquet_column_mapping=parquet_path_to_id_mapping(schema),
+            stats_columns=compute_statistics_plan(file_schema, 
table_metadata.properties),
+            parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
         )
         data_file = DataFile(
             content=DataFileContent.DATA,
diff --git a/tests/conftest.py b/tests/conftest.py
index e0d82910..66795436 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2060,6 +2060,7 @@ def spark() -> "SparkSession":
         .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
         .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
         .config("spark.sql.catalog.hive.s3.path-style-access", "true")
+        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
         .getOrCreate()
     )
 
diff --git a/tests/integration/test_inspect_table.py 
b/tests/integration/test_inspect_table.py
index 0905eda8..24c12763 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -171,7 +171,6 @@ def test_inspect_entries(
         for column in df.column_names:
             for left, right in zip(lhs[column].to_list(), 
rhs[column].to_list()):
                 if column == 'data_file':
-                    right = right.asDict(recursive=True)
                     for df_column in left.keys():
                         if df_column == 'partition':
                             # Spark leaves out the partition if the table is 
unpartitioned
@@ -185,8 +184,6 @@ def test_inspect_entries(
 
                         assert df_lhs == df_rhs, f"Difference in data_file 
column {df_column}: {df_lhs} != {df_rhs}"
                 elif column == 'readable_metrics':
-                    right = right.asDict(recursive=True)
-
                     assert list(left.keys()) == [
                         'bool',
                         'string',
diff --git a/tests/integration/test_writes/test_writes.py 
b/tests/integration/test_writes/test_writes.py
index 775a6f9d..2cf2c9ef 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -280,9 +280,27 @@ def 
test_python_writes_special_character_column_with_spark_reads(
     column_name_with_special_character = "letter/abc"
     TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
         column_name_with_special_character: ['a', None, 'z'],
+        'id': [1, 2, 3],
+        'name': ['AB', 'CD', 'EF'],
+        'address': [
+            {'street': '123', 'city': 'SFO', 'zip': 12345, 
column_name_with_special_character: 'a'},
+            {'street': '456', 'city': 'SW', 'zip': 67890, 
column_name_with_special_character: 'b'},
+            {'street': '789', 'city': 'Random', 'zip': 10112, 
column_name_with_special_character: 'c'},
+        ],
     }
     pa_schema = pa.schema([
-        (column_name_with_special_character, pa.string()),
+        pa.field(column_name_with_special_character, pa.string()),
+        pa.field('id', pa.int32()),
+        pa.field('name', pa.string()),
+        pa.field(
+            'address',
+            pa.struct([
+                pa.field('street', pa.string()),
+                pa.field('city', pa.string()),
+                pa.field('zip', pa.int32()),
+                pa.field(column_name_with_special_character, pa.string()),
+            ]),
+        ),
     ])
     arrow_table_with_special_character_column = 
pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
     tbl = _create_table(session_catalog, identifier, {"format-version": 
format_version}, schema=pa_schema)

Reply via email to