This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e61ef577 Add `include_field_ids` flag in `schema_to_pyarrow` (#789)
e61ef577 is described below

commit e61ef5770b4d73e683e2c78bebdd6c2165102a6b
Author: Sung Yun <[email protected]>
AuthorDate: Mon Jun 3 12:26:56 2024 -0400

    Add `include_field_ids` flag in `schema_to_pyarrow` (#789)
    
    * include_field_ids flag
    
    * include_field_ids flag
---
 pyiceberg/io/pyarrow.py  | 25 +++++++++++++--------
 tests/io/test_pyarrow.py | 57 ++++++++++++++++++++++++------------------------
 2 files changed, 45 insertions(+), 37 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 04f30ec6..71925c27 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -469,15 +469,18 @@ class PyArrowFileIO(FileIO):
         self.fs_by_scheme = lru_cache(self._initialize_fs)
 
 
-def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
-    return visit(schema, _ConvertToArrowSchema(metadata))
+def schema_to_pyarrow(
+    schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
+) -> pa.schema:
+    return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
 
 
 class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
     _metadata: Dict[bytes, bytes]
 
-    def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None:
+    def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
         self._metadata = metadata
+        self._include_field_ids = include_field_ids
 
     def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
         return pa.schema(list(struct_result), metadata=self._metadata)
@@ -486,13 +489,17 @@ class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
         return pa.struct(field_results)
 
     def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
+        metadata = {}
+        if field.doc:
+            metadata[PYARROW_FIELD_DOC_KEY] = field.doc
+        if self._include_field_ids:
+            metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
+
         return pa.field(
             name=field.name,
             type=field_result,
             nullable=field.optional,
-            metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}
-            if field.doc
-            else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)},
+            metadata=metadata,
         )
 
     def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
@@ -1130,7 +1137,7 @@ def project_table(
     tables = [f.result() for f in completed_futures if f.result()]
 
     if len(tables) < 1:
-        return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
+        return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
 
     result = pa.concat_tables(tables)
 
@@ -1161,7 +1168,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
         if field.field_type.is_primitive and field.field_type != file_field.field_type:
-            return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
+            return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1188,7 +1195,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
                 field_arrays.append(array)
                 fields.append(self._construct_field(field, array.type))
             elif field.optional:
-                arrow_type = schema_to_pyarrow(field.field_type)
+                arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
                 field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 fields.append(self._construct_field(field, arrow_type))
             else:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ec511f95..baa9e308 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None:
         assert "Cannot delete file, does not exist:" in str(exc_info.value)
 
 
-def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
+def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None:
     actual = schema_to_pyarrow(table_schema_nested)
     expected = """foo: string
   -- field metadata --
@@ -402,6 +402,30 @@ person: struct<name: string, age: int32 not null>
     assert repr(actual) == expected
 
 
+def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None:
+    actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
+    expected = """foo: string
+bar: int32 not null
+baz: bool
+qux: list<element: string not null> not null
+  child 0, element: string not null
+quux: map<string, map<string, int32>> not null
+  child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
+      child 0, key: string not null
+      child 1, value: map<string, int32> not null
+          child 0, entries: struct<key: string not null, value: int32 not null> not null
+              child 0, key: string not null
+              child 1, value: int32 not null
+location: list<element: struct<latitude: float, longitude: float> not null> not null
+  child 0, element: struct<latitude: float, longitude: float> not null
+      child 0, latitude: float
+      child 1, longitude: float
+person: struct<name: string, age: int32 not null>
+  child 0, name: string
+  child 1, age: int32 not null"""
+    assert repr(actual) == expected
+
+
 def test_fixed_type_to_pyarrow() -> None:
     length = 22
     iceberg_type = FixedType(length)
@@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None:
         == """id: int32
 list: list<element: int32>
   child 0, element: int32
-    -- field metadata --
-    PARQUET:field_id: '21'
 map: map<int32, string>
   child 0, entries: struct<key: int32 not null, value: string> not null
       child 0, key: int32 not null
-      -- field metadata --
-      PARQUET:field_id: '31'
       child 1, value: string
-      -- field metadata --
-      PARQUET:field_id: '32'
 location: struct<lat: double, lon: double>
   child 0, lat: double
-    -- field metadata --
-    PARQUET:field_id: '41'
-  child 1, lon: double
-    -- field metadata --
-    PARQUET:field_id: '42'"""
+  child 1, lon: double"""
     )
 
 
@@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
         == """id: map<int32, string>
   child 0, entries: struct<key: int32 not null, value: string> not null
       child 0, key: int32 not null
-      -- field metadata --
-      PARQUET:field_id: '3'
-      child 1, value: string
-      -- field metadata --
-      PARQUET:field_id: '4'"""
+      child 1, value: string"""
     )
 
 
@@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
 def test_projection_filter(schema_int: Schema, file_int: str) -> None:
     result_table = project(schema_int, [file_int], GreaterThan("id", 4))
     assert len(result_table.columns[0]) == 0
-    assert (
-        repr(result_table.schema)
-        == """id: int32
-  -- field metadata --
-  PARQUET:field_id: '1'"""
-    )
+    assert repr(result_table.schema) == """id: int32"""
 
 
 def test_projection_filter_renamed_column(file_int: str) -> None:
@@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
         repr(result_table.schema)
         == """location: struct<lat: double, long: double>
   child 0, lat: double
-    -- field metadata --
-    PARQUET:field_id: '41'
-  child 1, long: double
-    -- field metadata --
-    PARQUET:field_id: '42'"""
+  child 1, long: double"""
     )
 
 

Reply via email to