This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new d4a4eede Cast PyArrow schema to `large_*` types (#807)
d4a4eede is described below
commit d4a4eedee247c6a14f31383b0e43a91169d28539
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jun 14 16:13:19 2024 -0400
Cast PyArrow schema to `large_*` types (#807)
* _pyarrow_with
* fix
* fix test
* adopt review feedback
* revert accidental conf change
* adopt-nit
---
pyiceberg/io/pyarrow.py | 56 ++++++++++++---
tests/catalog/test_sql.py | 10 +--
tests/conftest.py | 4 +-
tests/integration/test_writes/test_writes.py | 54 ++++++++++++++
tests/io/test_pyarrow.py | 102 +++++++++++++--------------
tests/io/test_pyarrow_visitor.py | 39 +++++++++-
tests/test_schema.py | 2 +-
7 files changed, 197 insertions(+), 70 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 71925c27..935b78ce 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -504,7 +504,7 @@ class
_ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
def list(self, list_type: ListType, element_result: pa.DataType) ->
pa.DataType:
element_field = self.field(list_type.element_field, element_result)
- return pa.list_(value_type=element_field)
+ return pa.large_list(value_type=element_field)
def map(self, map_type: MapType, key_result: pa.DataType, value_result:
pa.DataType) -> pa.DataType:
key_field = self.field(map_type.key_field, key_result)
@@ -548,7 +548,7 @@ class
_ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
return pa.timestamp(unit="us", tz="UTC")
def visit_string(self, _: StringType) -> pa.DataType:
- return pa.string()
+ return pa.large_string()
def visit_uuid(self, _: UUIDType) -> pa.DataType:
return pa.binary(16)
@@ -680,6 +680,10 @@ def _pyarrow_to_schema_without_ids(schema: pa.Schema) ->
Schema:
return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
+def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
+ return visit_pyarrow(schema, _ConvertToLargeTypes())
+
+
@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor:
PyArrowSchemaVisitor[T]) -> T:
"""Apply a pyarrow schema visitor to any point within a schema.
@@ -952,6 +956,30 @@ class
_ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
self._field_names.pop()
+class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType,
pa.Schema]]):
+ def schema(self, schema: pa.Schema, struct_result: pa.StructType) ->
pa.Schema:
+ return pa.schema(struct_result)
+
+ def struct(self, struct: pa.StructType, field_results: List[pa.Field]) ->
pa.StructType:
+ return pa.struct(field_results)
+
+ def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
+ return field.with_type(field_result)
+
+ def list(self, list_type: pa.ListType, element_result: pa.DataType) ->
pa.DataType:
+ return pa.large_list(element_result)
+
+ def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result:
pa.DataType) -> pa.DataType:
+ return pa.map_(key_result, value_result)
+
+ def primitive(self, primitive: pa.DataType) -> pa.DataType:
+ if primitive == pa.string():
+ return pa.large_string()
+ elif primitive == pa.binary():
+ return pa.large_binary()
+ return primitive
+
+
class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
"""
Converts PyArrowSchema to Iceberg Schema with all -1 ids.
@@ -998,7 +1026,9 @@ def _task_to_table(
fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
- schema=physical_schema,
+ # We always use large types in memory because they use larger offsets,
+ # which allows more row values to fit into each buffer
+ schema=_pyarrow_schema_ensure_large_types(physical_schema),
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them
first
filter=pyarrow_filter if not positional_deletes else None,
@@ -1167,8 +1197,14 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
def _cast_if_needed(self, field: NestedField, values: pa.Array) ->
pa.Array:
file_field = self.file_schema.find_field(field.field_id)
- if field.field_type.is_primitive and field.field_type !=
file_field.field_type:
- return
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type),
include_field_ids=False))
+ if field.field_type.is_primitive:
+ if field.field_type != file_field.field_type:
+ return
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type),
include_field_ids=False))
+ elif (target_type := schema_to_pyarrow(field.field_type,
include_field_ids=False)) != values.type:
+ # If file_field and field_type (e.g. String) are the same,
+ # but the pyarrow type of the array is different from the
expected type
+ # (e.g. string vs large_string), we want to cast the array to
the larger type
+ return values.cast(target_type)
return values
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) ->
pa.Field:
@@ -1207,13 +1243,13 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
return field_array
def list(self, list_type: ListType, list_array: Optional[pa.Array],
value_array: Optional[pa.Array]) -> Optional[pa.Array]:
- if isinstance(list_array, pa.ListArray) and value_array is not None:
+ if isinstance(list_array, (pa.ListArray, pa.LargeListArray,
pa.FixedSizeListArray)) and value_array is not None:
if isinstance(value_array, pa.StructArray):
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
- list_array = pa.ListArray.from_arrays(list_array.offsets,
value_array)
+ list_array = pa.LargeListArray.from_arrays(list_array.offsets,
value_array)
- arrow_field =
pa.list_(self._construct_field(list_type.element_field, value_array.type))
+ arrow_field =
pa.large_list(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
return None
@@ -1263,7 +1299,7 @@ class ArrowAccessor(PartnerAccessor[pa.Array]):
return None
def list_element_partner(self, partner_list: Optional[pa.Array]) ->
Optional[pa.Array]:
- return partner_list.values if isinstance(partner_list, pa.ListArray)
else None
+ return partner_list.values if isinstance(partner_list, (pa.ListArray,
pa.LargeListArray, pa.FixedSizeListArray)) else None
def map_key_partner(self, partner_map: Optional[pa.Array]) ->
Optional[pa.Array]:
return partner_map.keys if isinstance(partner_map, pa.MapArray) else
None
@@ -1800,10 +1836,10 @@ def write_file(io: FileIO, table_metadata:
TableMetadata, tasks: Iterator[WriteT
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) !=
table_schema:
file_schema = sanitized_schema
- arrow_table = to_requested_schema(requested_schema=file_schema,
file_schema=table_schema, table=arrow_table)
else:
file_schema = table_schema
+ arrow_table = to_requested_schema(requested_schema=file_schema,
file_schema=table_schema, table=arrow_table)
file_path =
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 54591622..24adfb88 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -288,7 +288,7 @@ def test_write_pyarrow_schema(catalog: SqlCatalog,
table_identifier: Identifier)
pa.array([None, "A", "B", "C"]), # 'large' column
],
schema=pa.schema([
- pa.field("foo", pa.string(), nullable=True),
+ pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
pa.field("large", pa.large_string(), nullable=True),
@@ -1325,7 +1325,7 @@ def test_write_and_evolve(catalog: SqlCatalog,
format_version: int) -> None:
{
"foo": ["a", None, "z"],
},
- schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+ schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)
tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema,
properties={"format-version": str(format_version)})
@@ -1336,7 +1336,7 @@ def test_write_and_evolve(catalog: SqlCatalog,
format_version: int) -> None:
"bar": [19, None, 25],
},
schema=pa.schema([
- pa.field("foo", pa.string(), nullable=True),
+ pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
]),
)
@@ -1375,7 +1375,7 @@ def test_create_table_transaction(catalog: SqlCatalog,
format_version: int) -> N
{
"foo": ["a", None, "z"],
},
- schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+ schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)
pa_table_with_column = pa.Table.from_pydict(
@@ -1384,7 +1384,7 @@ def test_create_table_transaction(catalog: SqlCatalog,
format_version: int) -> N
"bar": [19, None, 25],
},
schema=pa.schema([
- pa.field("foo", pa.string(), nullable=True),
+ pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
]),
)
diff --git a/tests/conftest.py b/tests/conftest.py
index d3f23689..a160322e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2116,8 +2116,8 @@ def pa_schema() -> "pa.Schema":
return pa.schema([
("bool", pa.bool_()),
- ("string", pa.string()),
- ("string_long", pa.string()),
+ ("string", pa.large_string()),
+ ("string_long", pa.large_string()),
("int", pa.int32()),
("long", pa.int64()),
("float", pa.float32()),
diff --git a/tests/integration/test_writes/test_writes.py
b/tests/integration/test_writes/test_writes.py
index e329adcd..4585406c 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -340,6 +340,60 @@ def
test_python_writes_dictionary_encoded_column_with_spark_reads(
assert spark_df.equals(pyiceberg_df)
[email protected]
[email protected]("format_version", [1, 2])
+def test_python_writes_with_small_and_large_types_spark_reads(
+ spark: SparkSession, session_catalog: Catalog, format_version: int
+) -> None:
+ identifier = "default.python_writes_with_small_and_large_types_spark_reads"
+ TEST_DATA = {
+ "foo": ["a", None, "z"],
+ "id": [1, 2, 3],
+ "name": ["AB", "CD", "EF"],
+ "address": [
+ {"street": "123", "city": "SFO", "zip": 12345, "bar": "a"},
+ {"street": "456", "city": "SW", "zip": 67890, "bar": "b"},
+ {"street": "789", "city": "Random", "zip": 10112, "bar": "c"},
+ ],
+ }
+ pa_schema = pa.schema([
+ pa.field("foo", pa.large_string()),
+ pa.field("id", pa.int32()),
+ pa.field("name", pa.string()),
+ pa.field(
+ "address",
+ pa.struct([
+ pa.field("street", pa.string()),
+ pa.field("city", pa.string()),
+ pa.field("zip", pa.int32()),
+ pa.field("bar", pa.large_string()),
+ ]),
+ ),
+ ])
+ arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema)
+ tbl = _create_table(session_catalog, identifier, {"format-version":
format_version}, schema=pa_schema)
+
+ tbl.overwrite(arrow_table)
+ spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
+ pyiceberg_df = tbl.scan().to_pandas()
+ assert spark_df.equals(pyiceberg_df)
+ arrow_table_on_read = tbl.scan().to_arrow()
+ assert arrow_table_on_read.schema == pa.schema([
+ pa.field("foo", pa.large_string()),
+ pa.field("id", pa.int32()),
+ pa.field("name", pa.large_string()),
+ pa.field(
+ "address",
+ pa.struct([
+ pa.field("street", pa.large_string()),
+ pa.field("city", pa.large_string()),
+ pa.field("zip", pa.int32()),
+ pa.field("bar", pa.large_string()),
+ ]),
+ ),
+ ])
+
+
@pytest.mark.integration
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog:
Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.write_bin_pack_data_files"
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index baa9e308..ecb946a9 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -346,7 +346,7 @@ def test_deleting_hdfs_file_not_found() -> None:
def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested:
Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested)
- expected = """foo: string
+ expected = """foo: large_string
-- field metadata --
PARQUET:field_id: '1'
bar: int32 not null
@@ -355,20 +355,20 @@ bar: int32 not null
baz: bool
-- field metadata --
PARQUET:field_id: '3'
-qux: list<element: string not null> not null
- child 0, element: string not null
+qux: large_list<element: large_string not null> not null
+ child 0, element: large_string not null
-- field metadata --
PARQUET:field_id: '5'
-- field metadata --
PARQUET:field_id: '4'
-quux: map<string, map<string, int32>> not null
- child 0, entries: struct<key: string not null, value: map<string, int32> not
null> not null
- child 0, key: string not null
+quux: map<large_string, map<large_string, int32>> not null
+ child 0, entries: struct<key: large_string not null, value:
map<large_string, int32> not null> not null
+ child 0, key: large_string not null
-- field metadata --
PARQUET:field_id: '7'
- child 1, value: map<string, int32> not null
- child 0, entries: struct<key: string not null, value: int32 not
null> not null
- child 0, key: string not null
+ child 1, value: map<large_string, int32> not null
+ child 0, entries: struct<key: large_string not null, value: int32
not null> not null
+ child 0, key: large_string not null
-- field metadata --
PARQUET:field_id: '9'
child 1, value: int32 not null
@@ -378,7 +378,7 @@ quux: map<string, map<string, int32>> not null
PARQUET:field_id: '8'
-- field metadata --
PARQUET:field_id: '6'
-location: list<element: struct<latitude: float, longitude: float> not null>
not null
+location: large_list<element: struct<latitude: float, longitude: float> not
null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
-- field metadata --
@@ -390,8 +390,8 @@ location: list<element: struct<latitude: float, longitude:
float> not null> not
PARQUET:field_id: '12'
-- field metadata --
PARQUET:field_id: '11'
-person: struct<name: string, age: int32 not null>
- child 0, name: string
+person: struct<name: large_string, age: int32 not null>
+ child 0, name: large_string
-- field metadata --
PARQUET:field_id: '16'
child 1, age: int32 not null
@@ -404,24 +404,24 @@ person: struct<name: string, age: int32 not null>
def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested:
Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
- expected = """foo: string
+ expected = """foo: large_string
bar: int32 not null
baz: bool
-qux: list<element: string not null> not null
- child 0, element: string not null
-quux: map<string, map<string, int32>> not null
- child 0, entries: struct<key: string not null, value: map<string, int32> not
null> not null
- child 0, key: string not null
- child 1, value: map<string, int32> not null
- child 0, entries: struct<key: string not null, value: int32 not
null> not null
- child 0, key: string not null
+qux: large_list<element: large_string not null> not null
+ child 0, element: large_string not null
+quux: map<large_string, map<large_string, int32>> not null
+ child 0, entries: struct<key: large_string not null, value:
map<large_string, int32> not null> not null
+ child 0, key: large_string not null
+ child 1, value: map<large_string, int32> not null
+ child 0, entries: struct<key: large_string not null, value: int32
not null> not null
+ child 0, key: large_string not null
child 1, value: int32 not null
-location: list<element: struct<latitude: float, longitude: float> not null>
not null
+location: large_list<element: struct<latitude: float, longitude: float> not
null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
child 1, longitude: float
-person: struct<name: string, age: int32 not null>
- child 0, name: string
+person: struct<name: large_string, age: int32 not null>
+ child 0, name: large_string
child 1, age: int32 not null"""
assert repr(actual) == expected
@@ -486,7 +486,7 @@ def test_timestamptz_type_to_pyarrow() -> None:
def test_string_type_to_pyarrow() -> None:
iceberg_type = StringType()
- assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.string()
+ assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_string()
def test_binary_type_to_pyarrow() -> None:
@@ -496,7 +496,7 @@ def test_binary_type_to_pyarrow() -> None:
def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None:
expected = pa.struct([
- pa.field("foo", pa.string(), nullable=True, metadata={"field_id":
"1"}),
+ pa.field("foo", pa.large_string(), nullable=True,
metadata={"field_id": "1"}),
pa.field("bar", pa.int32(), nullable=False, metadata={"field_id":
"2"}),
pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
])
@@ -513,7 +513,7 @@ def test_map_type_to_pyarrow() -> None:
)
assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_(
pa.field("key", pa.int32(), nullable=False, metadata={"field_id":
"1"}),
- pa.field("value", pa.string(), nullable=False, metadata={"field_id":
"2"}),
+ pa.field("value", pa.large_string(), nullable=False,
metadata={"field_id": "2"}),
)
@@ -523,7 +523,7 @@ def test_list_type_to_pyarrow() -> None:
element_type=IntegerType(),
element_required=True,
)
- assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(
+ assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.large_list(
pa.field("element", pa.int32(), nullable=False, metadata={"field_id":
"1"})
)
@@ -606,11 +606,11 @@ def
test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]
def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
assert repr(expression_to_pyarrow(BoundIn(bound_reference,
{literal("hello"), literal("world")}))) in (
- """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+ """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
"hello",
"world"
], null_matching_behavior=MATCH})>""",
- """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+ """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
"world",
"hello"
], null_matching_behavior=MATCH})>""",
@@ -619,11 +619,11 @@ def test_expr_in_to_pyarrow(bound_reference:
BoundReference[str]) -> None:
def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
assert repr(expression_to_pyarrow(BoundNotIn(bound_reference,
{literal("hello"), literal("world")}))) in (
- """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+ """<pyarrow.compute.Expression invert(is_in(foo,
{value_set=large_string:[
"hello",
"world"
], null_matching_behavior=MATCH}))>""",
- """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+ """<pyarrow.compute.Expression invert(is_in(foo,
{value_set=large_string:[
"world",
"hello"
], null_matching_behavior=MATCH}))>""",
@@ -967,12 +967,12 @@ def test_projection_add_column(file_int: str) -> None:
assert (
repr(result_table.schema)
== """id: int32
-list: list<element: int32>
+list: large_list<element: int32>
child 0, element: int32
-map: map<int32, string>
- child 0, entries: struct<key: int32 not null, value: string> not null
+map: map<int32, large_string>
+ child 0, entries: struct<key: int32 not null, value: large_string> not null
child 0, key: int32 not null
- child 1, value: string
+ child 1, value: large_string
location: struct<lat: double, lon: double>
child 0, lat: double
child 1, lon: double"""
@@ -988,7 +988,7 @@ def test_read_list(schema_list: Schema, file_list: str) ->
None:
assert (
repr(result_table.schema)
- == """ids: list<element: int32>
+ == """ids: large_list<element: int32>
child 0, element: int32"""
)
@@ -1002,10 +1002,10 @@ def test_read_map(schema_map: Schema, file_map: str) ->
None:
assert (
repr(result_table.schema)
- == """properties: map<string, string>
- child 0, entries: struct<key: string not null, value: string not null> not
null
- child 0, key: string not null
- child 1, value: string not null"""
+ == """properties: map<large_string, large_string>
+ child 0, entries: struct<key: large_string not null, value: large_string not
null> not null
+ child 0, key: large_string not null
+ child 1, value: large_string not null"""
)
@@ -1025,10 +1025,10 @@ def test_projection_add_column_struct(schema_int:
Schema, file_int: str) -> None
assert r.as_py() is None
assert (
repr(result_table.schema)
- == """id: map<int32, string>
- child 0, entries: struct<key: int32 not null, value: string> not null
+ == """id: map<int32, large_string>
+ child 0, entries: struct<key: int32 not null, value: large_string> not null
child 0, key: int32 not null
- child 1, value: string"""
+ child 1, value: large_string"""
)
@@ -1231,7 +1231,7 @@ def
test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of
]
assert (
repr(result_table.schema)
- == """locations: list<element: struct<latitude: double not null,
longitude: double not null, altitude: double>>
+ == """locations: large_list<element: struct<latitude: double not null,
longitude: double not null, altitude: double>>
child 0, element: struct<latitude: double not null, longitude: double not
null, altitude: double>
child 0, latitude: double not null
child 1, longitude: double not null
@@ -1279,9 +1279,9 @@ def
test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_s
assert actual.as_py() == expected
assert (
repr(result_table.schema)
- == """locations: map<string, struct<latitude: double not null,
longitude: double not null, altitude: double>>
- child 0, entries: struct<key: string not null, value: struct<latitude:
double not null, longitude: double not null, altitude: double> not null> not
null
- child 0, key: string not null
+ == """locations: map<large_string, struct<latitude: double not null,
longitude: double not null, altitude: double>>
+ child 0, entries: struct<key: large_string not null, value: struct<latitude:
double not null, longitude: double not null, altitude: double> not null> not
null
+ child 0, key: large_string not null
child 1, value: struct<latitude: double not null, longitude: double not
null, altitude: double> not null
child 0, latitude: double not null
child 1, longitude: double not null
@@ -1378,7 +1378,7 @@ def test_delete(deletes_file: str, example_task:
FileScanTask, table_schema_simp
assert (
str(with_deletes)
== """pyarrow.Table
-foo: string
+foo: large_string
bar: int32 not null
baz: bool
----
@@ -1416,7 +1416,7 @@ def test_delete_duplicates(deletes_file: str,
example_task: FileScanTask, table_
assert (
str(with_deletes)
== """pyarrow.Table
-foo: string
+foo: large_string
bar: int32 not null
baz: bool
----
@@ -1447,7 +1447,7 @@ def test_pyarrow_wrap_fsspec(example_task: FileScanTask,
table_schema_simple: Sc
assert (
str(projection)
== """pyarrow.Table
-foo: string
+foo: large_string
bar: int32 not null
baz: bool
----
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index c8571dac..d3b6217c 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -25,6 +25,7 @@ from pyiceberg.io.pyarrow import (
_ConvertToIceberg,
_ConvertToIcebergWithoutIDs,
_HasIds,
+ _pyarrow_schema_ensure_large_types,
pyarrow_to_schema,
schema_to_pyarrow,
visit_pyarrow,
@@ -209,7 +210,7 @@ def test_pyarrow_timestamp_tz_invalid_tz() -> None:
def test_pyarrow_string_to_iceberg() -> None:
- pyarrow_type = pa.string()
+ pyarrow_type = pa.large_string()
converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
assert converted_iceberg_type == StringType()
assert visit(converted_iceberg_type, _ConvertToArrowSchema()) ==
pyarrow_type
@@ -543,3 +544,39 @@ def test_pyarrow_schema_to_schema_fresh_ids_nested_schema(
pyarrow_schema_nested_without_ids: pa.Schema,
iceberg_schema_nested_no_ids: Schema
) -> None:
assert visit_pyarrow(pyarrow_schema_nested_without_ids,
_ConvertToIcebergWithoutIDs()) == iceberg_schema_nested_no_ids
+
+
+def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids:
pa.Schema) -> None:
+ expected_schema = pa.schema([
+ pa.field("foo", pa.large_string(), nullable=False),
+ pa.field("bar", pa.int32(), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=True),
+ pa.field("qux", pa.large_list(pa.large_string()), nullable=False),
+ pa.field(
+ "quux",
+ pa.map_(
+ pa.large_string(),
+ pa.map_(pa.large_string(), pa.int32()),
+ ),
+ nullable=False,
+ ),
+ pa.field(
+ "location",
+ pa.large_list(
+ pa.struct([
+ pa.field("latitude", pa.float32(), nullable=False),
+ pa.field("longitude", pa.float32(), nullable=False),
+ ]),
+ ),
+ nullable=False,
+ ),
+ pa.field(
+ "person",
+ pa.struct([
+ pa.field("name", pa.large_string(), nullable=True),
+ pa.field("age", pa.int32(), nullable=False),
+ ]),
+ nullable=True,
+ ),
+ ])
+ assert
_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) ==
expected_schema
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 96109ce9..23b42ef4 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1610,7 +1610,7 @@ def test_arrow_schema() -> None:
)
expected_schema = pa.schema([
- pa.field("foo", pa.string(), nullable=False),
+ pa.field("foo", pa.large_string(), nullable=False),
pa.field("bar", pa.int32(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
])