This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 62b527e7 Sanitize special character column names when writing (#590)
62b527e7 is described below
commit 62b527e74cff328fde5faa848519c56748a244c1
Author: Kevin Liu <[email protected]>
AuthorDate: Wed Apr 17 04:47:02 2024 -0700
Sanitize special character column names when writing (#590)
* write with sanitized column names
* push down to when parquet writes
* add test for writing special character column name
* parameterize format_version
* use to_requested_schema
* refactor to_requested_schema
* more refactor
* test nested schema
* special character inside nested field
* comment on why arrow is enabled
* use existing variable
* move spark config to conftest
* pyspark arrow turns pandas df from tuple to dict
* Revert refactor to_requested_schema
* reorder args
* refactor
* pushdown schema
* only transform when necessary
---
pyiceberg/io/pyarrow.py | 24 +++++++++++++++---------
tests/conftest.py | 1 +
tests/integration/test_inspect_table.py | 3 ---
tests/integration/test_writes/test_writes.py | 20 +++++++++++++++++++-
4 files changed, 35 insertions(+), 13 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 67ebaa81..f8deb2f9 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -122,6 +122,7 @@ from pyiceberg.schema import (
pre_order_visit,
promote,
prune_columns,
+ sanitize_column_names,
visit,
visit_with_partner,
)
@@ -1016,7 +1017,6 @@ def _task_to_table(
if len(arrow_table) < 1:
return None
-
return to_requested_schema(projected_schema, file_project_schema,
arrow_table)
@@ -1769,10 +1769,7 @@ def data_file_statistics_from_parquet_metadata(
def write_file(io: FileIO, table_metadata: TableMetadata, tasks:
Iterator[WriteTask]) -> Iterator[DataFile]:
- schema = table_metadata.schema()
- arrow_file_schema = schema.as_arrow()
parquet_writer_kwargs =
_get_parquet_writer_kwargs(table_metadata.properties)
-
row_group_size = PropertyUtil.property_as_int(
properties=table_metadata.properties,
property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
@@ -1780,16 +1777,25 @@ def write_file(io: FileIO, table_metadata:
TableMetadata, tasks: Iterator[WriteT
)
def write_parquet(task: WriteTask) -> DataFile:
+ table_schema = task.schema
+ arrow_table = pa.Table.from_batches(task.record_batches)
+ # if schema needs to be transformed, use the transformed schema and
adjust the arrow table accordingly
+ # otherwise use the original schema
+ if (sanitized_schema := sanitize_column_names(table_schema)) !=
table_schema:
+ file_schema = sanitized_schema
+ arrow_table = to_requested_schema(requested_schema=file_schema,
file_schema=table_schema, table=arrow_table)
+ else:
+ file_schema = table_schema
+
file_path =
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
- with pq.ParquetWriter(fos, schema=arrow_file_schema,
**parquet_writer_kwargs) as writer:
- writer.write(pa.Table.from_batches(task.record_batches),
row_group_size=row_group_size)
-
+ with pq.ParquetWriter(fos, schema=file_schema.as_arrow(),
**parquet_writer_kwargs) as writer:
+ writer.write(arrow_table, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
- stats_columns=compute_statistics_plan(schema,
table_metadata.properties),
- parquet_column_mapping=parquet_path_to_id_mapping(schema),
+ stats_columns=compute_statistics_plan(file_schema,
table_metadata.properties),
+ parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
)
data_file = DataFile(
content=DataFileContent.DATA,
diff --git a/tests/conftest.py b/tests/conftest.py
index e0d82910..66795436 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2060,6 +2060,7 @@ def spark() -> "SparkSession":
.config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
.config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
.config("spark.sql.catalog.hive.s3.path-style-access", "true")
+ .config("spark.sql.execution.arrow.pyspark.enabled", "true")
.getOrCreate()
)
diff --git a/tests/integration/test_inspect_table.py
b/tests/integration/test_inspect_table.py
index 0905eda8..24c12763 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -171,7 +171,6 @@ def test_inspect_entries(
for column in df.column_names:
for left, right in zip(lhs[column].to_list(),
rhs[column].to_list()):
if column == 'data_file':
- right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is
unpartitioned
@@ -185,8 +184,6 @@ def test_inspect_entries(
assert df_lhs == df_rhs, f"Difference in data_file
column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
- right = right.asDict(recursive=True)
-
assert list(left.keys()) == [
'bool',
'string',
diff --git a/tests/integration/test_writes/test_writes.py
b/tests/integration/test_writes/test_writes.py
index 775a6f9d..2cf2c9ef 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -280,9 +280,27 @@ def
test_python_writes_special_character_column_with_spark_reads(
column_name_with_special_character = "letter/abc"
TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
column_name_with_special_character: ['a', None, 'z'],
+ 'id': [1, 2, 3],
+ 'name': ['AB', 'CD', 'EF'],
+ 'address': [
+ {'street': '123', 'city': 'SFO', 'zip': 12345,
column_name_with_special_character: 'a'},
+ {'street': '456', 'city': 'SW', 'zip': 67890,
column_name_with_special_character: 'b'},
+ {'street': '789', 'city': 'Random', 'zip': 10112,
column_name_with_special_character: 'c'},
+ ],
}
pa_schema = pa.schema([
- (column_name_with_special_character, pa.string()),
+ pa.field(column_name_with_special_character, pa.string()),
+ pa.field('id', pa.int32()),
+ pa.field('name', pa.string()),
+ pa.field(
+ 'address',
+ pa.struct([
+ pa.field('street', pa.string()),
+ pa.field('city', pa.string()),
+ pa.field('zip', pa.int32()),
+ pa.field(column_name_with_special_character, pa.string()),
+ ]),
+ ),
])
arrow_table_with_special_character_column =
pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
tbl = _create_table(session_catalog, identifier, {"format-version":
format_version}, schema=pa_schema)