This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 2a27f2b Arrow: Don't copy the list/map when not needed (#252)
2a27f2b is described below
commit 2a27f2b66ef76efcf0996cd0fcc52eaad58b9cca
Author: Fokko Driesprong <[email protected]>
AuthorDate: Fri Jan 26 08:52:44 2024 +0100
Arrow: Don't copy the list/map when not needed (#252)
---
dev/provision.py | 22 +++++++
pyiceberg/io/pyarrow.py | 51 ++++++++++------
tests/integration/test_reads.py | 13 ++++
tests/io/test_pyarrow.py | 129 +++++++++++++++++++++++++++++++++++-----
4 files changed, 183 insertions(+), 32 deletions(-)
diff --git a/dev/provision.py b/dev/provision.py
index e5048d2..44086ca 100644
--- a/dev/provision.py
+++ b/dev/provision.py
@@ -320,3 +320,25 @@ for catalog_name, catalog in catalogs.items():
spark.sql(f"ALTER TABLE {catalog_name}.default.test_table_add_column ADD
COLUMN b string")
spark.sql(f"INSERT INTO {catalog_name}.default.test_table_add_column
VALUES ('2', '2')")
+
+ spark.sql(
+ f"""
+ CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map (
+ col_list array<int>,
+ col_map map<int, int>,
+ col_list_with_struct array<struct<test:int>>
+ )
+ USING iceberg
+ TBLPROPERTIES (
+ 'format-version'='1'
+ );
+ """
+ )
+
+ spark.sql(
+ f"""
+ INSERT INTO {catalog_name}.default.test_table_empty_list_and_map
+ VALUES (null, null, null),
+ (array(), map(), array(struct(1)))
+ """
+ )
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 035f5e8..cbfb9f6 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -168,6 +168,7 @@ PYARROW_FIELD_DOC_KEY = b"doc"
LIST_ELEMENT_NAME = "element"
MAP_KEY_NAME = "key"
MAP_VALUE_NAME = "value"
+DOC = "doc"
T = TypeVar("T")
@@ -1118,12 +1119,20 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
def __init__(self, file_schema: Schema):
self.file_schema = file_schema
- def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
+ def _cast_if_needed(self, field: NestedField, values: pa.Array) ->
pa.Array:
file_field = self.file_schema.find_field(field.field_id)
if field.field_type.is_primitive and field.field_type !=
file_field.field_type:
return
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
return values
+ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) ->
pa.Field:
+ return pa.field(
+ name=field.name,
+ type=arrow_type,
+ nullable=field.optional,
+ metadata={DOC: field.doc} if field.doc is not None else None,
+ )
+
def schema(self, schema: Schema, schema_partner: Optional[pa.Array],
struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
return struct_result
@@ -1136,13 +1145,13 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
fields: List[pa.Field] = []
for field, field_array in zip(struct.fields, field_results):
if field_array is not None:
- array = self.cast_if_needed(field, field_array)
+ array = self._cast_if_needed(field, field_array)
field_arrays.append(array)
- fields.append(pa.field(field.name, array.type, field.optional))
+ fields.append(self._construct_field(field, array.type))
elif field.optional:
arrow_type = schema_to_pyarrow(field.field_type)
field_arrays.append(pa.nulls(len(struct_array),
type=arrow_type))
- fields.append(pa.field(field.name, arrow_type, field.optional))
+ fields.append(self._construct_field(field, arrow_type))
else:
raise ResolveError(f"Field is required, and could not be found
in the file: {field}")
@@ -1152,24 +1161,32 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
return field_array
def list(self, list_type: ListType, list_array: Optional[pa.Array],
value_array: Optional[pa.Array]) -> Optional[pa.Array]:
- return (
- pa.ListArray.from_arrays(list_array.offsets,
self.cast_if_needed(list_type.element_field, value_array))
- if isinstance(list_array, pa.ListArray)
- else None
- )
+ if isinstance(list_array, pa.ListArray) and value_array is not None:
+ if isinstance(value_array, pa.StructArray):
+ # This can be removed once this has been fixed:
+ # https://github.com/apache/arrow/issues/38809
+ list_array = pa.ListArray.from_arrays(list_array.offsets,
value_array)
+
+ arrow_field =
pa.list_(self._construct_field(list_type.element_field, value_array.type))
+ return list_array.cast(arrow_field)
+ else:
+ return None
def map(
self, map_type: MapType, map_array: Optional[pa.Array], key_result:
Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
- return (
- pa.MapArray.from_arrays(
- map_array.offsets,
- self.cast_if_needed(map_type.key_field, key_result),
- self.cast_if_needed(map_type.value_field, value_result),
+ if isinstance(map_array, pa.MapArray) and key_result is not None and
value_result is not None:
+ arrow_field = pa.map_(
+ self._construct_field(map_type.key_field, key_result.type),
+ self._construct_field(map_type.value_field, value_result.type),
)
- if isinstance(map_array, pa.MapArray)
- else None
- )
+ if isinstance(value_result, pa.StructArray):
+ # Arrow does not allow reordering of fields, therefore we have
to copy the array :(
+ return pa.MapArray.from_arrays(map_array.offsets, key_result,
value_result, arrow_field)
+ else:
+ return map_array.cast(arrow_field)
+ else:
+ return None
def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) ->
Optional[pa.Array]:
return array
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index e7c8b74..3fc06fb 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -428,3 +428,16 @@ def test_sanitize_character(catalog: Catalog) -> None:
assert len(arrow_table.schema.names), 1
assert len(table_test_table_sanitized_character.schema().fields), 1
assert arrow_table.schema.names[0] ==
table_test_table_sanitized_character.schema().fields[0].name
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'),
pytest.lazy_fixture('catalog_rest')])
+def test_null_list_and_map(catalog: Catalog) -> None:
+ table_test_empty_list_and_map =
catalog.load_table("default.test_table_empty_list_and_map")
+ arrow_table = table_test_empty_list_and_map.scan().to_arrow()
+ assert arrow_table["col_list"].to_pylist() == [None, []]
+ assert arrow_table["col_map"].to_pylist() == [None, []]
+ # This should be:
+ # assert arrow_table["col_list_with_struct"].to_pylist() == [None,
[{'test': 1}]]
+ # Once https://github.com/apache/arrow/issues/38809 has been fixed
+ assert arrow_table["col_list_with_struct"].to_pylist() == [[], [{'test':
1}]]
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 5efeb42..e6f4de3 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -682,6 +682,24 @@ def schema_list_of_structs() -> Schema:
)
+@pytest.fixture
+def schema_map_of_structs() -> Schema:
+ return Schema(
+ NestedField(
+ 5,
+ "locations",
+ MapType(
+ key_id=51,
+ value_id=52,
+ key_type=StringType(),
+ value_type=StructType(NestedField(511, "lat", DoubleType()),
NestedField(512, "long", DoubleType())),
+ element_required=False,
+ ),
+ required=False,
+ ),
+ )
+
+
@pytest.fixture
def schema_map() -> Schema:
return Schema(
@@ -793,6 +811,25 @@ def file_list_of_structs(schema_list_of_structs: Schema,
tmpdir: str) -> str:
)
+@pytest.fixture
+def file_map_of_structs(schema_map_of_structs: Schema, tmpdir: str) -> str:
+ pyarrow_schema = schema_to_pyarrow(
+ schema_map_of_structs, metadata={ICEBERG_SCHEMA:
bytes(schema_map_of_structs.model_dump_json(), UTF8)}
+ )
+ return _write_table_to_file(
+ f"file:{tmpdir}/e.parquet",
+ pyarrow_schema,
+ pa.Table.from_pylist(
+ [
+ {"locations": {"1": {"lat": 52.371807, "long": 4.896029}, "2":
{"lat": 52.387386, "long": 4.646219}}},
+ {"locations": {}},
+ {"locations": {"3": {"lat": 52.078663, "long": 4.288788}, "4":
{"lat": 52.387386, "long": 4.646219}}},
+ ],
+ schema=pyarrow_schema,
+ ),
+ )
+
+
@pytest.fixture
def file_map(schema_map: Schema, tmpdir: str) -> str:
pyarrow_schema = schema_to_pyarrow(schema_map, metadata={ICEBERG_SCHEMA:
bytes(schema_map.model_dump_json(), UTF8)})
@@ -914,7 +951,11 @@ def test_read_list(schema_list: Schema, file_list: str) ->
None:
for actual, expected in zip(result_table.columns[0], [list(range(1, 10)),
list(range(2, 20)), list(range(3, 30))]):
assert actual.as_py() == expected
- assert repr(result_table.schema) == "ids: list<item: int32>\n child 0,
item: int32"
+ assert (
+ repr(result_table.schema)
+ == """ids: list<element: int32>
+ child 0, element: int32"""
+ )
def test_read_map(schema_map: Schema, file_map: str) -> None:
@@ -927,9 +968,9 @@ def test_read_map(schema_map: Schema, file_map: str) ->
None:
assert (
repr(result_table.schema)
== """properties: map<string, string>
- child 0, entries: struct<key: string not null, value: string> not null
+ child 0, entries: struct<key: string not null, value: string not null> not
null
child 0, key: string not null
- child 1, value: string"""
+ child 1, value: string not null"""
)
@@ -1063,7 +1104,11 @@ def test_projection_nested_struct_subset(file_struct:
str) -> None:
assert actual.as_py() == {"lat": expected}
assert len(result_table.columns[0]) == 3
- assert repr(result_table.schema) == "location: struct<lat: double not
null> not null\n child 0, lat: double not null"
+ assert (
+ repr(result_table.schema)
+ == """location: struct<lat: double not null> not null
+ child 0, lat: double not null"""
+ )
def test_projection_nested_new_field(file_struct: str) -> None:
@@ -1082,7 +1127,11 @@ def test_projection_nested_new_field(file_struct: str)
-> None:
for actual, expected in zip(result_table.columns[0], [None, None, None]):
assert actual.as_py() == {"null": expected}
assert len(result_table.columns[0]) == 3
- assert repr(result_table.schema) == "location: struct<null: double> not
null\n child 0, null: double"
+ assert (
+ repr(result_table.schema)
+ == """location: struct<null: double> not null
+ child 0, null: double"""
+ )
def test_projection_nested_struct(schema_struct: Schema, file_struct: str) ->
None:
@@ -1111,7 +1160,10 @@ def test_projection_nested_struct(schema_struct: Schema,
file_struct: str) -> No
assert len(result_table.columns[0]) == 3
assert (
repr(result_table.schema)
- == "location: struct<lat: double, null: double, long: double> not
null\n child 0, lat: double\n child 1, null: double\n child 2, long: double"
+ == """location: struct<lat: double, null: double, long: double> not
null
+ child 0, lat: double
+ child 1, null: double
+ child 2, long: double"""
)
@@ -1136,28 +1188,75 @@ def
test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of
result_table = project(schema, [file_list_of_structs])
assert len(result_table.columns) == 1
assert len(result_table.columns[0]) == 3
+ results = [row.as_py() for row in result_table.columns[0]]
+ assert results == [
+ [
+ {'latitude': 52.371807, 'longitude': 4.896029, 'altitude': None},
+ {'latitude': 52.387386, 'longitude': 4.646219, 'altitude': None},
+ ],
+ [],
+ [
+ {'latitude': 52.078663, 'longitude': 4.288788, 'altitude': None},
+ {'latitude': 52.387386, 'longitude': 4.646219, 'altitude': None},
+ ],
+ ]
+ assert (
+ repr(result_table.schema)
+ == """locations: list<element: struct<latitude: double not null,
longitude: double not null, altitude: double>>
+ child 0, element: struct<latitude: double not null, longitude: double not
null, altitude: double>
+ child 0, latitude: double not null
+ child 1, longitude: double not null
+ child 2, altitude: double"""
+ )
+
+
+def test_projection_maps_of_structs(schema_map_of_structs: Schema,
file_map_of_structs: str) -> None:
+ schema = Schema(
+ NestedField(
+ 5,
+ "locations",
+ MapType(
+ key_id=51,
+ value_id=52,
+ key_type=StringType(),
+ value_type=StructType(
+ NestedField(511, "latitude", DoubleType()),
+ NestedField(512, "longitude", DoubleType()),
+ NestedField(513, "altitude", DoubleType(), required=False),
+ ),
+ element_required=False,
+ ),
+ required=False,
+ ),
+ )
+
+ result_table = project(schema, [file_map_of_structs])
+ assert len(result_table.columns) == 1
+ assert len(result_table.columns[0]) == 3
for actual, expected in zip(
result_table.columns[0],
[
[
- {"latitude": 52.371807, "longitude": 4.896029, "altitude":
None},
- {"latitude": 52.387386, "longitude": 4.646219, "altitude":
None},
+ ("1", {"latitude": 52.371807, "longitude": 4.896029,
"altitude": None}),
+ ("2", {"latitude": 52.387386, "longitude": 4.646219,
"altitude": None}),
],
[],
[
- {"latitude": 52.078663, "longitude": 4.288788, "altitude":
None},
- {"latitude": 52.387386, "longitude": 4.646219, "altitude":
None},
+ ("3", {"latitude": 52.078663, "longitude": 4.288788,
"altitude": None}),
+ ("4", {"latitude": 52.387386, "longitude": 4.646219,
"altitude": None}),
],
],
):
assert actual.as_py() == expected
assert (
repr(result_table.schema)
- == """locations: list<item: struct<latitude: double not null,
longitude: double not null, altitude: double>>
- child 0, item: struct<latitude: double not null, longitude: double not null,
altitude: double>
- child 0, latitude: double not null
- child 1, longitude: double not null
- child 2, altitude: double"""
+ == """locations: map<string, struct<latitude: double not null,
longitude: double not null, altitude: double>>
+ child 0, entries: struct<key: string not null, value: struct<latitude:
double not null, longitude: double not null, altitude: double> not null> not
null
+ child 0, key: string not null
+ child 1, value: struct<latitude: double not null, longitude: double not
null, altitude: double> not null
+ child 0, latitude: double not null
+ child 1, longitude: double not null
+ child 2, altitude: double"""
)