This is an automated email from the ASF dual-hosted git repository.

sungwy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 5cce906d Use `VisitorWithPartner` for name-mapping (#1014)
5cce906d is described below

commit 5cce906db89fa1edbb57bb423ba371598ce50acb
Author: Fokko Driesprong <[email protected]>
AuthorDate: Tue Aug 13 14:34:24 2024 +0200

    Use `VisitorWithPartner` for name-mapping (#1014)
    
    * Use `VisitorWithPartner` for name-mapping
    
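    This correctly handles field names that contain a `.`: the mapping is now
    walked one nesting level at a time instead of looking up a `.`-joined path,
    which could not tell a dotted name apart from a nested field.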
    
    * Fix versions in deprecation
    
    Co-authored-by: Sung Yun <[email protected]>
    
    * Use full path in error
    
    ---------
    
    Co-authored-by: Sung Yun <[email protected]>
---
 pyiceberg/io/pyarrow.py          |  16 ++---
 pyiceberg/table/name_mapping.py  | 134 ++++++++++++++++++++++++++++++++++++++-
 tests/table/test_name_mapping.py |  52 ++++++++++++++-
 3 files changed, 189 insertions(+), 13 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 719d2897..b2cb167a 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -130,7 +130,7 @@ from pyiceberg.schema import (
     visit_with_partner,
 )
 from pyiceberg.table.metadata import TableMetadata
-from pyiceberg.table.name_mapping import NameMapping
+from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
 from pyiceberg.transforms import TruncateTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties, Record
 from pyiceberg.types import (
@@ -818,14 +818,14 @@ def pyarrow_to_schema(
 ) -> Schema:
     has_ids = visit_pyarrow(schema, _HasIds())
     if has_ids:
-        visitor = 
_ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        return visit_pyarrow(schema, 
_ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
     elif name_mapping is not None:
-        visitor = _ConvertToIceberg(name_mapping=name_mapping, 
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        schema_without_ids = _pyarrow_to_schema_without_ids(schema, 
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        return apply_name_mapping(schema_without_ids, name_mapping)
     else:
         raise ValueError(
             "Parquet file does not have field-ids and the Iceberg table does 
not have 'schema.name-mapping.default' defined"
         )
-    return visit_pyarrow(schema, visitor)
 
 
 def _pyarrow_to_schema_without_ids(schema: pa.Schema, 
downcast_ns_timestamp_to_us: bool = False) -> Schema:
@@ -1002,17 +1002,13 @@ class 
_ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
     """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from 
name_mapping if provided."""
 
     _field_names: List[str]
-    _name_mapping: Optional[NameMapping]
 
-    def __init__(self, name_mapping: Optional[NameMapping] = None, 
downcast_ns_timestamp_to_us: bool = False) -> None:
+    def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
         self._field_names = []
-        self._name_mapping = name_mapping
         self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _field_id(self, field: pa.Field) -> int:
-        if self._name_mapping:
-            return self._name_mapping.find(*self._field_names).field_id
-        elif (field_id := _get_field_id(field)) is not None:
+        if (field_id := _get_field_id(field)) is not None:
             return field_id
         else:
             raise ValueError(f"Cannot convert {field} to Iceberg Field as 
field_id is empty.")
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index cb9f72bf..eaf5fc85 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -30,9 +30,10 @@ from typing import Any, Dict, Generic, Iterator, List, 
Optional, TypeVar, Union
 
 from pydantic import Field, conlist, field_validator, model_serializer
 
-from pyiceberg.schema import Schema, SchemaVisitor, visit
+from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, 
SchemaWithPartnerVisitor, visit, visit_with_partner
 from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
-from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, 
StructType
+from pyiceberg.types import IcebergType, ListType, MapType, NestedField, 
PrimitiveType, StructType
+from pyiceberg.utils.deprecated import deprecated
 
 
 class MappedField(IcebergBaseModel):
@@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
     def _field_by_name(self) -> Dict[str, MappedField]:
         return visit_name_mapping(self, _IndexByName())
 
+    @deprecated(
+        deprecated_in="0.8.0",
+        removed_in="0.9.0",
+        help_message="Please use `apply_name_mapping` instead",
+    )
     def find(self, *names: str) -> MappedField:
         name = ".".join(names)
         try:
@@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> 
NameMapping:
 
 def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], 
adds: Dict[int, List[NestedField]]) -> NameMapping:
     return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, 
adds)))
+
+
+class NameMappingAccessor(PartnerAccessor[MappedField]):
+    """Find the NameMapping entry that partners each node of the schema being visited.
+
+    Matching is by exact name against ``MappedField.names``, one nesting level at a
+    time, so a field name containing a ``.`` is compared as a single name.
+    """
+
+    @staticmethod
+    def _child_named(parent: Optional[MappedField], child_name: str) -> Optional[MappedField]:
+        # Shared lookup for the fixed child names of lists ("element") and maps ("key"/"value").
+        if parent is None:
+            return None
+        return next((child for child in parent.fields if child_name in child.names), None)
+
+    def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
+        return partner
+
+    def field_partner(
+        self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
+    ) -> Optional[MappedField]:
+        if partner_struct is None:
+            return None
+        # At the top level the partner is the mapping itself (iterated directly);
+        # for nested structs it is a MappedField whose children live in `.fields`.
+        candidates = partner_struct.fields if isinstance(partner_struct, MappedField) else partner_struct
+        return next((candidate for candidate in candidates if field_name in candidate.names), None)
+
+    def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
+        return self._child_named(partner_list, "element")
+
+    def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+        return self._child_named(partner_map, "key")
+
+    def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+        return self._child_named(partner_map, "value")
+
+
+class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
+    """Rebuild a schema that lacks field IDs, taking every ID from the partnered NameMapping entry.
+
+    Names, types, requiredness, docs and defaults are copied from the input schema;
+    only the IDs come from the mapping. A node without a mapping entry raises a
+    ValueError naming its dotted path.
+    """
+
+    # Names from the root down to the node being visited; used only in error messages.
+    current_path: List[str]
+
+    def __init__(self) -> None:
+        # For keeping track where we are in case when a field cannot be found
+        self.current_path = []
+
+    def _child_id(self, partner: MappedField, child_name: str) -> int:
+        """Return the field ID of the ``child_name`` entry (element/key/value) under ``partner``.
+
+        Raises ValueError with the full path, instead of letting a bare ``next()``
+        leak StopIteration, when the mapping has no such entry.
+        """
+        for child in partner.fields:
+            if child_name in child.names:
+                return child.field_id
+        raise ValueError(f"Could not find field with name: {'.'.join([*self.current_path, child_name])}")
+
+    def before_field(self, field: NestedField, field_partner: Optional[MappedField]) -> None:
+        self.current_path.append(field.name)
+
+    def after_field(self, field: NestedField, field_partner: Optional[MappedField]) -> None:
+        self.current_path.pop()
+
+    def before_list_element(self, element: NestedField, element_partner: Optional[MappedField]) -> None:
+        self.current_path.append("element")
+
+    def after_list_element(self, element: NestedField, element_partner: Optional[MappedField]) -> None:
+        self.current_path.pop()
+
+    def before_map_key(self, key: NestedField, key_partner: Optional[MappedField]) -> None:
+        self.current_path.append("key")
+
+    def after_map_key(self, key: NestedField, key_partner: Optional[MappedField]) -> None:
+        self.current_path.pop()
+
+    def before_map_value(self, value: NestedField, value_partner: Optional[MappedField]) -> None:
+        self.current_path.append("value")
+
+    def after_map_value(self, value: NestedField, value_partner: Optional[MappedField]) -> None:
+        self.current_path.pop()
+
+    def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
+        return Schema(*struct_result.fields, schema_id=schema.schema_id)
+
+    def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
+        return StructType(*field_results)
+
+    def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
+        if field_partner is None:
+            raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")
+
+        return NestedField(
+            field_id=field_partner.field_id,
+            name=field.name,
+            field_type=field_result,
+            required=field.required,
+            doc=field.doc,
+            initial_default=field.initial_default,
+            # Keyword must match the attribute it is copied from; `initial_write=` did not carry it over.
+            write_default=field.write_default,
+        )
+
+    def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
+        if list_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        return ListType(
+            element_id=self._child_id(list_partner, "element"),
+            element=element_result,
+            element_required=list_type.element_required,
+        )
+
+    def map(
+        self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
+    ) -> IcebergType:
+        if map_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        return MapType(
+            key_id=self._child_id(map_partner, "key"),
+            key_type=key_result,
+            value_id=self._child_id(map_partner, "value"),
+            value_type=value_result,
+            value_required=map_type.value_required,
+        )
+
+    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
+        if primitive_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        return primitive
+
+
+def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
+    """Return a copy of ``schema_without_ids`` whose field IDs are taken from ``name_mapping``.
+
+    Fields are matched by name one nesting level at a time, so names containing a
+    ``.`` are matched whole. Raises ValueError when a node has no mapping entry.
+    """
+    projection = NameMappingProjectionVisitor()
+    accessor = NameMappingAccessor()
+    return visit_with_partner(schema_without_ids, name_mapping, projection, accessor)  # type: ignore
diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py
index 3c50a24e..647644fa 100644
--- a/tests/table/test_name_mapping.py
+++ b/tests/table/test_name_mapping.py
@@ -20,11 +20,12 @@ from pyiceberg.schema import Schema
 from pyiceberg.table.name_mapping import (
     MappedField,
     NameMapping,
+    apply_name_mapping,
     create_mapping_from_schema,
     parse_mapping_from_json,
     update_mapping,
 )
-from pyiceberg.types import NestedField, StringType
+from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, 
MapType, NestedField, StringType, StructType
 
 
 @pytest.fixture(scope="session")
@@ -321,3 +322,52 @@ def test_update_mapping(table_name_mapping_nested: 
NameMapping) -> None:
         MappedField(field_id=18, names=["add_18"]),
     ])
     assert update_mapping(table_name_mapping_nested, updates, adds) == expected
+
+
+def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None:
+    schema_without_ids = Schema(
+        NestedField(field_id=0, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=0,
+            name="qux",
+            field_type=ListType(element_id=0, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="quux",
+            field_type=MapType(
+                key_id=0,
+                key_type=StringType(),
+                value_id=0,
+                value_type=MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="location",
+            field_type=ListType(
+                element_id=0,
+                element_type=StructType(
+                    NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False),
+                    NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False),
+                ),
+                element_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=0, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=0, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+    assert apply_name_mapping(schema_without_ids, table_name_mapping_nested).fields == table_schema_nested.fields
+
+
+def test_mapping_using_by_visitor_dotted_field_name() -> None:
+    # A name containing "." must be matched whole at its own level, not read as a nested path.
+    schema_without_ids = Schema(NestedField(field_id=0, name="a.b", field_type=StringType(), required=False))
+    mapping = NameMapping([MappedField(field_id=1, names=["a.b"])])
+    expected = Schema(NestedField(field_id=1, name="a.b", field_type=StringType(), required=False))
+    assert apply_name_mapping(schema_without_ids, mapping).fields == expected.fields

Reply via email to