This is an automated email from the ASF dual-hosted git repository.
sungwy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 5cce906d Use `VisitorWithPartner` for name-mapping (#1014)
5cce906d is described below
commit 5cce906db89fa1edbb57bb423ba371598ce50acb
Author: Fokko Driesprong <[email protected]>
AuthorDate: Tue Aug 13 14:34:24 2024 +0200
Use `VisitorWithPartner` for name-mapping (#1014)
* Use `VisitorWithPartner` for name-mapping
This will correctly handle fields with `.` in the name.
* Fix versions in deprecation
Co-authored-by: Sung Yun <[email protected]>
* Use full path in error
---------
Co-authored-by: Sung Yun <[email protected]>
---
pyiceberg/io/pyarrow.py | 16 ++---
pyiceberg/table/name_mapping.py | 134 ++++++++++++++++++++++++++++++++++++++-
tests/table/test_name_mapping.py | 52 ++++++++++++++-
3 files changed, 189 insertions(+), 13 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 719d2897..b2cb167a 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -130,7 +130,7 @@ from pyiceberg.schema import (
     visit_with_partner,
 )
 from pyiceberg.table.metadata import TableMetadata
-from pyiceberg.table.name_mapping import NameMapping
+from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
 from pyiceberg.transforms import TruncateTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties, Record
 from pyiceberg.types import (
@@ -818,14 +818,14 @@ def pyarrow_to_schema(
 ) -> Schema:
     has_ids = visit_pyarrow(schema, _HasIds())
     if has_ids:
-        visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        return visit_pyarrow(schema, _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
     elif name_mapping is not None:
-        visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        schema_without_ids = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        return apply_name_mapping(schema_without_ids, name_mapping)
     else:
         raise ValueError(
             "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
         )
-    return visit_pyarrow(schema, visitor)
 
 
 def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
@@ -1002,17 +1002,13 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
     """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""
 
     _field_names: List[str]
-    _name_mapping: Optional[NameMapping]
 
-    def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
+    def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
         self._field_names = []
-        self._name_mapping = name_mapping
         self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _field_id(self, field: pa.Field) -> int:
-        if self._name_mapping:
-            return self._name_mapping.find(*self._field_names).field_id
-        elif (field_id := _get_field_id(field)) is not None:
+        if (field_id := _get_field_id(field)) is not None:
             return field_id
         else:
             raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index cb9f72bf..eaf5fc85 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -30,9 +30,10 @@ from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
 
 from pydantic import Field, conlist, field_validator, model_serializer
 
-from pyiceberg.schema import Schema, SchemaVisitor, visit
+from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner
 from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
-from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
+from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
+from pyiceberg.utils.deprecated import deprecated
 
 
 class MappedField(IcebergBaseModel):
@@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
     def _field_by_name(self) -> Dict[str, MappedField]:
         return visit_name_mapping(self, _IndexByName())
 
+    @deprecated(
+        deprecated_in="0.8.0",
+        removed_in="0.9.0",
+        help_message="Please use `apply_name_mapping` instead",
+    )
     def find(self, *names: str) -> MappedField:
         name = ".".join(names)
         try:
@@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> NameMapping:
 
 def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
     return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))
+
+
+class NameMappingAccessor(PartnerAccessor[MappedField]):
+    def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
+        return partner
+
+    def field_partner(
+        self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
+    ) -> Optional[MappedField]:
+        if partner_struct is not None:
+            if isinstance(partner_struct, MappedField):
+                partner_struct = partner_struct.fields
+
+            for field in partner_struct:
+                if field_name in field.names:
+                    return field
+
+        return None
+
+    def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
+        if partner_list is not None:
+            for field in partner_list.fields:
+                if "element" in field.names:
+                    return field
+        return None
+
+    def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+        if partner_map is not None:
+            for field in partner_map.fields:
+                if "key" in field.names:
+                    return field
+        return None
+
+    def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+        if partner_map is not None:
+            for field in partner_map.fields:
+                if "value" in field.names:
+                    return field
+        return None
+
+
+class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
+    current_path: List[str]
+
+    def __init__(self) -> None:
+        # For keeping track where we are in case when a field cannot be found
+        self.current_path = []
+
+    def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+        self.current_path.append(field.name)
+
+    def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+        self.current_path.pop()
+
+    def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+        self.current_path.append("element")
+
+    def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+        self.current_path.pop()
+
+    def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+        self.current_path.append("key")
+
+    def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+        self.current_path.pop()
+
+    def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+        self.current_path.append("value")
+
+    def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+        self.current_path.pop()
+
+    def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
+        return Schema(*struct_result.fields, schema_id=schema.schema_id)
+
+    def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
+        return StructType(*field_results)
+
+    def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
+        if field_partner is None:
+            raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")
+
+        return NestedField(
+            field_id=field_partner.field_id,
+            name=field.name,
+            field_type=field_result,
+            required=field.required,
+            doc=field.doc,
+            initial_default=field.initial_default,
+            write_default=field.write_default,
+        )
+
+    def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
+        if list_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        element_id = next(field for field in list_partner.fields if "element" in field.names).field_id
+        return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)
+
+    def map(
+        self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
+    ) -> IcebergType:
+        if map_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        key_id = next(field for field in map_partner.fields if "key" in field.names).field_id
+        value_id = next(field for field in map_partner.fields if "value" in field.names).field_id
+        return MapType(
+            key_id=key_id,
+            key_type=key_result,
+            value_id=value_id,
+            value_type=value_result,
+            value_required=map_type.value_required,
+        )
+
+    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
+        if primitive_partner is None:
+            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+        return primitive
+
+
+def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
+    return visit_with_partner(schema_without_ids, name_mapping, NameMappingProjectionVisitor(), NameMappingAccessor())  # type: ignore
diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py
index 3c50a24e..647644fa 100644
--- a/tests/table/test_name_mapping.py
+++ b/tests/table/test_name_mapping.py
@@ -20,11 +20,12 @@ from pyiceberg.schema import Schema
 from pyiceberg.table.name_mapping import (
     MappedField,
     NameMapping,
+    apply_name_mapping,
     create_mapping_from_schema,
     parse_mapping_from_json,
     update_mapping,
 )
-from pyiceberg.types import NestedField, StringType
+from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, MapType, NestedField, StringType, StructType
 
 
 @pytest.fixture(scope="session")
@@ -321,3 +322,52 @@ def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
         MappedField(field_id=18, names=["add_18"]),
     ])
     assert update_mapping(table_name_mapping_nested, updates, adds) == expected
+
+
+def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None:
+    schema_without_ids = Schema(
+        NestedField(field_id=0, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=0,
+            name="qux",
+            field_type=ListType(element_id=0, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="quux",
+            field_type=MapType(
+                key_id=0,
+                key_type=StringType(),
+                value_id=0,
+                value_type=MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="location",
+            field_type=ListType(
+                element_id=0,
+                element_type=StructType(
+                    NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False),
+                    NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False),
+                ),
+                element_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=0,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=0, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=0, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+    assert apply_name_mapping(schema_without_ids, table_name_mapping_nested).fields == table_schema_nested.fields