This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new e61ef577 Add `include_field_ids` flag in `schema_to_pyarrow` (#789)
e61ef577 is described below
commit e61ef5770b4d73e683e2c78bebdd6c2165102a6b
Author: Sung Yun <[email protected]>
AuthorDate: Mon Jun 3 12:26:56 2024 -0400
Add `include_field_ids` flag in `schema_to_pyarrow` (#789)
* include_field_ids flag
* include_field_ids flag
---
pyiceberg/io/pyarrow.py | 25 +++++++++++++--------
tests/io/test_pyarrow.py | 57 ++++++++++++++++++++++++------------------------
2 files changed, 45 insertions(+), 37 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 04f30ec6..71925c27 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -469,15 +469,18 @@ class PyArrowFileIO(FileIO):
self.fs_by_scheme = lru_cache(self._initialize_fs)
-def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata:
Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
- return visit(schema, _ConvertToArrowSchema(metadata))
+def schema_to_pyarrow(
+ schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] =
EMPTY_DICT, include_field_ids: bool = True
+) -> pa.schema:
+ return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
_metadata: Dict[bytes, bytes]
- def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None:
+ def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT,
include_field_ids: bool = True) -> None:
self._metadata = metadata
+ self._include_field_ids = include_field_ids
def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
return pa.schema(list(struct_result), metadata=self._metadata)
@@ -486,13 +489,17 @@ class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
return pa.struct(field_results)
def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
+ metadata = {}
+ if field.doc:
+ metadata[PYARROW_FIELD_DOC_KEY] = field.doc
+ if self._include_field_ids:
+ metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
+
return pa.field(
name=field.name,
type=field_result,
nullable=field.optional,
- metadata={PYARROW_FIELD_DOC_KEY: field.doc,
PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}
- if field.doc
- else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)},
+ metadata=metadata,
)
def list(self, list_type: ListType, element_result: pa.DataType) ->
pa.DataType:
@@ -1130,7 +1137,7 @@ def project_table(
tables = [f.result() for f in completed_futures if f.result()]
if len(tables) < 1:
- return pa.Table.from_batches([],
schema=schema_to_pyarrow(projected_schema))
+ return pa.Table.from_batches([],
schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
result = pa.concat_tables(tables)
@@ -1161,7 +1168,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
def _cast_if_needed(self, field: NestedField, values: pa.Array) ->
pa.Array:
file_field = self.file_schema.find_field(field.field_id)
if field.field_type.is_primitive and field.field_type !=
file_field.field_type:
- return
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
+ return
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type),
include_field_ids=False))
return values
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) ->
pa.Field:
@@ -1188,7 +1195,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
field_arrays.append(array)
fields.append(self._construct_field(field, array.type))
elif field.optional:
- arrow_type = schema_to_pyarrow(field.field_type)
+ arrow_type = schema_to_pyarrow(field.field_type,
include_field_ids=False)
field_arrays.append(pa.nulls(len(struct_array),
type=arrow_type))
fields.append(self._construct_field(field, arrow_type))
else:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ec511f95..baa9e308 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None:
assert "Cannot delete file, does not exist:" in str(exc_info.value)
-def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
+def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested:
Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested)
expected = """foo: string
-- field metadata --
@@ -402,6 +402,30 @@ person: struct<name: string, age: int32 not null>
assert repr(actual) == expected
+def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested:
Schema) -> None:
+ actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
+ expected = """foo: string
+bar: int32 not null
+baz: bool
+qux: list<element: string not null> not null
+ child 0, element: string not null
+quux: map<string, map<string, int32>> not null
+ child 0, entries: struct<key: string not null, value: map<string, int32> not
null> not null
+ child 0, key: string not null
+ child 1, value: map<string, int32> not null
+ child 0, entries: struct<key: string not null, value: int32 not
null> not null
+ child 0, key: string not null
+ child 1, value: int32 not null
+location: list<element: struct<latitude: float, longitude: float> not null>
not null
+ child 0, element: struct<latitude: float, longitude: float> not null
+ child 0, latitude: float
+ child 1, longitude: float
+person: struct<name: string, age: int32 not null>
+ child 0, name: string
+ child 1, age: int32 not null"""
+ assert repr(actual) == expected
+
+
def test_fixed_type_to_pyarrow() -> None:
length = 22
iceberg_type = FixedType(length)
@@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None:
== """id: int32
list: list<element: int32>
child 0, element: int32
- -- field metadata --
- PARQUET:field_id: '21'
map: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
- -- field metadata --
- PARQUET:field_id: '31'
child 1, value: string
- -- field metadata --
- PARQUET:field_id: '32'
location: struct<lat: double, lon: double>
child 0, lat: double
- -- field metadata --
- PARQUET:field_id: '41'
- child 1, lon: double
- -- field metadata --
- PARQUET:field_id: '42'"""
+ child 1, lon: double"""
)
@@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
== """id: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
- -- field metadata --
- PARQUET:field_id: '3'
- child 1, value: string
- -- field metadata --
- PARQUET:field_id: '4'"""
+ child 1, value: string"""
)
@@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0
- assert (
- repr(result_table.schema)
- == """id: int32
- -- field metadata --
- PARQUET:field_id: '1'"""
- )
+ assert repr(result_table.schema) == """id: int32"""
def test_projection_filter_renamed_column(file_int: str) -> None:
@@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
repr(result_table.schema)
== """location: struct<lat: double, long: double>
child 0, lat: double
- -- field metadata --
- PARQUET:field_id: '41'
- child 1, long: double
- -- field metadata --
- PARQUET:field_id: '42'"""
+ child 1, long: double"""
)