This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new d4a4eede Cast PyArrow schema to `large_*` types (#807)
d4a4eede is described below

commit d4a4eedee247c6a14f31383b0e43a91169d28539
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jun 14 16:13:19 2024 -0400

    Cast PyArrow schema to `large_*` types (#807)
    
    * _pyarrow_with
    
    * fix
    
    * fix test
    
    * adopt review feedback
    
    * revert accidental conf change
    
    * adopt-nit
---
 pyiceberg/io/pyarrow.py                      |  56 ++++++++++++---
 tests/catalog/test_sql.py                    |  10 +--
 tests/conftest.py                            |   4 +-
 tests/integration/test_writes/test_writes.py |  54 ++++++++++++++
 tests/io/test_pyarrow.py                     | 102 +++++++++++++--------------
 tests/io/test_pyarrow_visitor.py             |  39 +++++++++-
 tests/test_schema.py                         |   2 +-
 7 files changed, 197 insertions(+), 70 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 71925c27..935b78ce 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -504,7 +504,7 @@ class 
_ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
 
     def list(self, list_type: ListType, element_result: pa.DataType) -> 
pa.DataType:
         element_field = self.field(list_type.element_field, element_result)
-        return pa.list_(value_type=element_field)
+        return pa.large_list(value_type=element_field)
 
     def map(self, map_type: MapType, key_result: pa.DataType, value_result: 
pa.DataType) -> pa.DataType:
         key_field = self.field(map_type.key_field, key_result)
@@ -548,7 +548,7 @@ class 
_ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
         return pa.timestamp(unit="us", tz="UTC")
 
     def visit_string(self, _: StringType) -> pa.DataType:
-        return pa.string()
+        return pa.large_string()
 
     def visit_uuid(self, _: UUIDType) -> pa.DataType:
         return pa.binary(16)
@@ -680,6 +680,10 @@ def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> 
Schema:
     return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
 
 
+def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
+    return visit_pyarrow(schema, _ConvertToLargeTypes())
+
+
 @singledispatch
 def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: 
PyArrowSchemaVisitor[T]) -> T:
     """Apply a pyarrow schema visitor to any point within a schema.
@@ -952,6 +956,30 @@ class 
_ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
         self._field_names.pop()
 
 
+class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, 
pa.Schema]]):
+    def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> 
pa.Schema:
+        return pa.schema(struct_result)
+
+    def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> 
pa.StructType:
+        return pa.struct(field_results)
+
+    def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
+        return field.with_type(field_result)
+
+    def list(self, list_type: pa.ListType, element_result: pa.DataType) -> 
pa.DataType:
+        return pa.large_list(element_result)
+
+    def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: 
pa.DataType) -> pa.DataType:
+        return pa.map_(key_result, value_result)
+
+    def primitive(self, primitive: pa.DataType) -> pa.DataType:
+        if primitive == pa.string():
+            return pa.large_string()
+        elif primitive == pa.binary():
+            return pa.large_binary()
+        return primitive
+
+
 class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
     """
     Converts PyArrowSchema to Iceberg Schema with all -1 ids.
@@ -998,7 +1026,9 @@ def _task_to_table(
 
         fragment_scanner = ds.Scanner.from_fragment(
             fragment=fragment,
-            schema=physical_schema,
+            # We always use large types in memory because they use larger offsets,
+            # which lets more row values be chunked into each buffer
+            schema=_pyarrow_schema_ensure_large_types(physical_schema),
             # This will push down the query to Arrow.
             # But in case there are positional deletes, we have to apply them 
first
             filter=pyarrow_filter if not positional_deletes else None,
@@ -1167,8 +1197,14 @@ class 
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> 
pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
-        if field.field_type.is_primitive and field.field_type != 
file_field.field_type:
-            return 
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), 
include_field_ids=False))
+        if field.field_type.is_primitive:
+            if field.field_type != file_field.field_type:
+                return 
values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), 
include_field_ids=False))
+            elif (target_type := schema_to_pyarrow(field.field_type, 
include_field_ids=False)) != values.type:
+                # if file_field and field_type (e.g. String) are the same
+                # but the pyarrow type of the array is different from the 
expected type
+                # (e.g. string vs large_string), we want to cast the array to 
the larger type
+                return values.cast(target_type)
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> 
pa.Field:
@@ -1207,13 +1243,13 @@ class 
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
         return field_array
 
     def list(self, list_type: ListType, list_array: Optional[pa.Array], 
value_array: Optional[pa.Array]) -> Optional[pa.Array]:
-        if isinstance(list_array, pa.ListArray) and value_array is not None:
+        if isinstance(list_array, (pa.ListArray, pa.LargeListArray, 
pa.FixedSizeListArray)) and value_array is not None:
             if isinstance(value_array, pa.StructArray):
                 # This can be removed once this has been fixed:
                 # https://github.com/apache/arrow/issues/38809
-                list_array = pa.ListArray.from_arrays(list_array.offsets, 
value_array)
+                list_array = pa.LargeListArray.from_arrays(list_array.offsets, 
value_array)
 
-            arrow_field = 
pa.list_(self._construct_field(list_type.element_field, value_array.type))
+            arrow_field = 
pa.large_list(self._construct_field(list_type.element_field, value_array.type))
             return list_array.cast(arrow_field)
         else:
             return None
@@ -1263,7 +1299,7 @@ class ArrowAccessor(PartnerAccessor[pa.Array]):
         return None
 
     def list_element_partner(self, partner_list: Optional[pa.Array]) -> 
Optional[pa.Array]:
-        return partner_list.values if isinstance(partner_list, pa.ListArray) 
else None
+        return partner_list.values if isinstance(partner_list, (pa.ListArray, 
pa.LargeListArray, pa.FixedSizeListArray)) else None
 
     def map_key_partner(self, partner_map: Optional[pa.Array]) -> 
Optional[pa.Array]:
         return partner_map.keys if isinstance(partner_map, pa.MapArray) else 
None
@@ -1800,10 +1836,10 @@ def write_file(io: FileIO, table_metadata: 
TableMetadata, tasks: Iterator[WriteT
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != 
table_schema:
             file_schema = sanitized_schema
-            arrow_table = to_requested_schema(requested_schema=file_schema, 
file_schema=table_schema, table=arrow_table)
         else:
             file_schema = table_schema
 
+        arrow_table = to_requested_schema(requested_schema=file_schema, 
file_schema=table_schema, table=arrow_table)
         file_path = 
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 54591622..24adfb88 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -288,7 +288,7 @@ def test_write_pyarrow_schema(catalog: SqlCatalog, 
table_identifier: Identifier)
             pa.array([None, "A", "B", "C"]),  # 'large' column
         ],
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
             pa.field("bar", pa.int32(), nullable=False),
             pa.field("baz", pa.bool_(), nullable=True),
             pa.field("large", pa.large_string(), nullable=True),
@@ -1325,7 +1325,7 @@ def test_write_and_evolve(catalog: SqlCatalog, 
format_version: int) -> None:
         {
             "foo": ["a", None, "z"],
         },
-        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+        schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
     )
 
     tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema, 
properties={"format-version": str(format_version)})
@@ -1336,7 +1336,7 @@ def test_write_and_evolve(catalog: SqlCatalog, 
format_version: int) -> None:
             "bar": [19, None, 25],
         },
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
             pa.field("bar", pa.int32(), nullable=True),
         ]),
     )
@@ -1375,7 +1375,7 @@ def test_create_table_transaction(catalog: SqlCatalog, 
format_version: int) -> N
         {
             "foo": ["a", None, "z"],
         },
-        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+        schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
     )
 
     pa_table_with_column = pa.Table.from_pydict(
@@ -1384,7 +1384,7 @@ def test_create_table_transaction(catalog: SqlCatalog, 
format_version: int) -> N
             "bar": [19, None, 25],
         },
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
             pa.field("bar", pa.int32(), nullable=True),
         ]),
     )
diff --git a/tests/conftest.py b/tests/conftest.py
index d3f23689..a160322e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2116,8 +2116,8 @@ def pa_schema() -> "pa.Schema":
 
     return pa.schema([
         ("bool", pa.bool_()),
-        ("string", pa.string()),
-        ("string_long", pa.string()),
+        ("string", pa.large_string()),
+        ("string_long", pa.large_string()),
         ("int", pa.int32()),
         ("long", pa.int64()),
         ("float", pa.float32()),
diff --git a/tests/integration/test_writes/test_writes.py 
b/tests/integration/test_writes/test_writes.py
index e329adcd..4585406c 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -340,6 +340,60 @@ def 
test_python_writes_dictionary_encoded_column_with_spark_reads(
     assert spark_df.equals(pyiceberg_df)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_python_writes_with_small_and_large_types_spark_reads(
+    spark: SparkSession, session_catalog: Catalog, format_version: int
+) -> None:
+    identifier = "default.python_writes_with_small_and_large_types_spark_reads"
+    TEST_DATA = {
+        "foo": ["a", None, "z"],
+        "id": [1, 2, 3],
+        "name": ["AB", "CD", "EF"],
+        "address": [
+            {"street": "123", "city": "SFO", "zip": 12345, "bar": "a"},
+            {"street": "456", "city": "SW", "zip": 67890, "bar": "b"},
+            {"street": "789", "city": "Random", "zip": 10112, "bar": "c"},
+        ],
+    }
+    pa_schema = pa.schema([
+        pa.field("foo", pa.large_string()),
+        pa.field("id", pa.int32()),
+        pa.field("name", pa.string()),
+        pa.field(
+            "address",
+            pa.struct([
+                pa.field("street", pa.string()),
+                pa.field("city", pa.string()),
+                pa.field("zip", pa.int32()),
+                pa.field("bar", pa.large_string()),
+            ]),
+        ),
+    ])
+    arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema)
+    tbl = _create_table(session_catalog, identifier, {"format-version": 
format_version}, schema=pa_schema)
+
+    tbl.overwrite(arrow_table)
+    spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
+    pyiceberg_df = tbl.scan().to_pandas()
+    assert spark_df.equals(pyiceberg_df)
+    arrow_table_on_read = tbl.scan().to_arrow()
+    assert arrow_table_on_read.schema == pa.schema([
+        pa.field("foo", pa.large_string()),
+        pa.field("id", pa.int32()),
+        pa.field("name", pa.large_string()),
+        pa.field(
+            "address",
+            pa.struct([
+                pa.field("street", pa.large_string()),
+                pa.field("city", pa.large_string()),
+                pa.field("zip", pa.int32()),
+                pa.field("bar", pa.large_string()),
+            ]),
+        ),
+    ])
+
+
 @pytest.mark.integration
 def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: 
Catalog, arrow_table_with_null: pa.Table) -> None:
     identifier = "default.write_bin_pack_data_files"
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index baa9e308..ecb946a9 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -346,7 +346,7 @@ def test_deleting_hdfs_file_not_found() -> None:
 
 def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: 
Schema) -> None:
     actual = schema_to_pyarrow(table_schema_nested)
-    expected = """foo: string
+    expected = """foo: large_string
   -- field metadata --
   PARQUET:field_id: '1'
 bar: int32 not null
@@ -355,20 +355,20 @@ bar: int32 not null
 baz: bool
   -- field metadata --
   PARQUET:field_id: '3'
-qux: list<element: string not null> not null
-  child 0, element: string not null
+qux: large_list<element: large_string not null> not null
+  child 0, element: large_string not null
     -- field metadata --
     PARQUET:field_id: '5'
   -- field metadata --
   PARQUET:field_id: '4'
-quux: map<string, map<string, int32>> not null
-  child 0, entries: struct<key: string not null, value: map<string, int32> not 
null> not null
-      child 0, key: string not null
+quux: map<large_string, map<large_string, int32>> not null
+  child 0, entries: struct<key: large_string not null, value: 
map<large_string, int32> not null> not null
+      child 0, key: large_string not null
       -- field metadata --
       PARQUET:field_id: '7'
-      child 1, value: map<string, int32> not null
-          child 0, entries: struct<key: string not null, value: int32 not 
null> not null
-              child 0, key: string not null
+      child 1, value: map<large_string, int32> not null
+          child 0, entries: struct<key: large_string not null, value: int32 
not null> not null
+              child 0, key: large_string not null
           -- field metadata --
           PARQUET:field_id: '9'
               child 1, value: int32 not null
@@ -378,7 +378,7 @@ quux: map<string, map<string, int32>> not null
       PARQUET:field_id: '8'
   -- field metadata --
   PARQUET:field_id: '6'
-location: list<element: struct<latitude: float, longitude: float> not null> 
not null
+location: large_list<element: struct<latitude: float, longitude: float> not 
null> not null
   child 0, element: struct<latitude: float, longitude: float> not null
       child 0, latitude: float
       -- field metadata --
@@ -390,8 +390,8 @@ location: list<element: struct<latitude: float, longitude: 
float> not null> not
     PARQUET:field_id: '12'
   -- field metadata --
   PARQUET:field_id: '11'
-person: struct<name: string, age: int32 not null>
-  child 0, name: string
+person: struct<name: large_string, age: int32 not null>
+  child 0, name: large_string
     -- field metadata --
     PARQUET:field_id: '16'
   child 1, age: int32 not null
@@ -404,24 +404,24 @@ person: struct<name: string, age: int32 not null>
 
 def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: 
Schema) -> None:
     actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
-    expected = """foo: string
+    expected = """foo: large_string
 bar: int32 not null
 baz: bool
-qux: list<element: string not null> not null
-  child 0, element: string not null
-quux: map<string, map<string, int32>> not null
-  child 0, entries: struct<key: string not null, value: map<string, int32> not 
null> not null
-      child 0, key: string not null
-      child 1, value: map<string, int32> not null
-          child 0, entries: struct<key: string not null, value: int32 not 
null> not null
-              child 0, key: string not null
+qux: large_list<element: large_string not null> not null
+  child 0, element: large_string not null
+quux: map<large_string, map<large_string, int32>> not null
+  child 0, entries: struct<key: large_string not null, value: 
map<large_string, int32> not null> not null
+      child 0, key: large_string not null
+      child 1, value: map<large_string, int32> not null
+          child 0, entries: struct<key: large_string not null, value: int32 
not null> not null
+              child 0, key: large_string not null
               child 1, value: int32 not null
-location: list<element: struct<latitude: float, longitude: float> not null> 
not null
+location: large_list<element: struct<latitude: float, longitude: float> not 
null> not null
   child 0, element: struct<latitude: float, longitude: float> not null
       child 0, latitude: float
       child 1, longitude: float
-person: struct<name: string, age: int32 not null>
-  child 0, name: string
+person: struct<name: large_string, age: int32 not null>
+  child 0, name: large_string
   child 1, age: int32 not null"""
     assert repr(actual) == expected
 
@@ -486,7 +486,7 @@ def test_timestamptz_type_to_pyarrow() -> None:
 
 def test_string_type_to_pyarrow() -> None:
     iceberg_type = StringType()
-    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.string()
+    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_string()
 
 
 def test_binary_type_to_pyarrow() -> None:
@@ -496,7 +496,7 @@ def test_binary_type_to_pyarrow() -> None:
 
 def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None:
     expected = pa.struct([
-        pa.field("foo", pa.string(), nullable=True, metadata={"field_id": 
"1"}),
+        pa.field("foo", pa.large_string(), nullable=True, 
metadata={"field_id": "1"}),
         pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": 
"2"}),
         pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
     ])
@@ -513,7 +513,7 @@ def test_map_type_to_pyarrow() -> None:
     )
     assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_(
         pa.field("key", pa.int32(), nullable=False, metadata={"field_id": 
"1"}),
-        pa.field("value", pa.string(), nullable=False, metadata={"field_id": 
"2"}),
+        pa.field("value", pa.large_string(), nullable=False, 
metadata={"field_id": "2"}),
     )
 
 
@@ -523,7 +523,7 @@ def test_list_type_to_pyarrow() -> None:
         element_type=IntegerType(),
         element_required=True,
     )
-    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(
+    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.large_list(
         pa.field("element", pa.int32(), nullable=False, metadata={"field_id": 
"1"})
     )
 
@@ -606,11 +606,11 @@ def 
test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]
 
 def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
     assert repr(expression_to_pyarrow(BoundIn(bound_reference, 
{literal("hello"), literal("world")}))) in (
-        """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+        """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
   "hello",
   "world"
 ], null_matching_behavior=MATCH})>""",
-        """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+        """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
   "world",
   "hello"
 ], null_matching_behavior=MATCH})>""",
@@ -619,11 +619,11 @@ def test_expr_in_to_pyarrow(bound_reference: 
BoundReference[str]) -> None:
 
 def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
     assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, 
{literal("hello"), literal("world")}))) in (
-        """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+        """<pyarrow.compute.Expression invert(is_in(foo, 
{value_set=large_string:[
   "hello",
   "world"
 ], null_matching_behavior=MATCH}))>""",
-        """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+        """<pyarrow.compute.Expression invert(is_in(foo, 
{value_set=large_string:[
   "world",
   "hello"
 ], null_matching_behavior=MATCH}))>""",
@@ -967,12 +967,12 @@ def test_projection_add_column(file_int: str) -> None:
     assert (
         repr(result_table.schema)
         == """id: int32
-list: list<element: int32>
+list: large_list<element: int32>
   child 0, element: int32
-map: map<int32, string>
-  child 0, entries: struct<key: int32 not null, value: string> not null
+map: map<int32, large_string>
+  child 0, entries: struct<key: int32 not null, value: large_string> not null
       child 0, key: int32 not null
-      child 1, value: string
+      child 1, value: large_string
 location: struct<lat: double, lon: double>
   child 0, lat: double
   child 1, lon: double"""
@@ -988,7 +988,7 @@ def test_read_list(schema_list: Schema, file_list: str) -> 
None:
 
     assert (
         repr(result_table.schema)
-        == """ids: list<element: int32>
+        == """ids: large_list<element: int32>
   child 0, element: int32"""
     )
 
@@ -1002,10 +1002,10 @@ def test_read_map(schema_map: Schema, file_map: str) -> 
None:
 
     assert (
         repr(result_table.schema)
-        == """properties: map<string, string>
-  child 0, entries: struct<key: string not null, value: string not null> not 
null
-      child 0, key: string not null
-      child 1, value: string not null"""
+        == """properties: map<large_string, large_string>
+  child 0, entries: struct<key: large_string not null, value: large_string not 
null> not null
+      child 0, key: large_string not null
+      child 1, value: large_string not null"""
     )
 
 
@@ -1025,10 +1025,10 @@ def test_projection_add_column_struct(schema_int: 
Schema, file_int: str) -> None
         assert r.as_py() is None
     assert (
         repr(result_table.schema)
-        == """id: map<int32, string>
-  child 0, entries: struct<key: int32 not null, value: string> not null
+        == """id: map<int32, large_string>
+  child 0, entries: struct<key: int32 not null, value: large_string> not null
       child 0, key: int32 not null
-      child 1, value: string"""
+      child 1, value: large_string"""
     )
 
 
@@ -1231,7 +1231,7 @@ def 
test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of
     ]
     assert (
         repr(result_table.schema)
-        == """locations: list<element: struct<latitude: double not null, 
longitude: double not null, altitude: double>>
+        == """locations: large_list<element: struct<latitude: double not null, 
longitude: double not null, altitude: double>>
   child 0, element: struct<latitude: double not null, longitude: double not 
null, altitude: double>
       child 0, latitude: double not null
       child 1, longitude: double not null
@@ -1279,9 +1279,9 @@ def 
test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_s
         assert actual.as_py() == expected
     assert (
         repr(result_table.schema)
-        == """locations: map<string, struct<latitude: double not null, 
longitude: double not null, altitude: double>>
-  child 0, entries: struct<key: string not null, value: struct<latitude: 
double not null, longitude: double not null, altitude: double> not null> not 
null
-      child 0, key: string not null
+        == """locations: map<large_string, struct<latitude: double not null, 
longitude: double not null, altitude: double>>
+  child 0, entries: struct<key: large_string not null, value: struct<latitude: 
double not null, longitude: double not null, altitude: double> not null> not 
null
+      child 0, key: large_string not null
       child 1, value: struct<latitude: double not null, longitude: double not 
null, altitude: double> not null
           child 0, latitude: double not null
           child 1, longitude: double not null
@@ -1378,7 +1378,7 @@ def test_delete(deletes_file: str, example_task: 
FileScanTask, table_schema_simp
     assert (
         str(with_deletes)
         == """pyarrow.Table
-foo: string
+foo: large_string
 bar: int32 not null
 baz: bool
 ----
@@ -1416,7 +1416,7 @@ def test_delete_duplicates(deletes_file: str, 
example_task: FileScanTask, table_
     assert (
         str(with_deletes)
         == """pyarrow.Table
-foo: string
+foo: large_string
 bar: int32 not null
 baz: bool
 ----
@@ -1447,7 +1447,7 @@ def test_pyarrow_wrap_fsspec(example_task: FileScanTask, 
table_schema_simple: Sc
     assert (
         str(projection)
         == """pyarrow.Table
-foo: string
+foo: large_string
 bar: int32 not null
 baz: bool
 ----
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index c8571dac..d3b6217c 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -25,6 +25,7 @@ from pyiceberg.io.pyarrow import (
     _ConvertToIceberg,
     _ConvertToIcebergWithoutIDs,
     _HasIds,
+    _pyarrow_schema_ensure_large_types,
     pyarrow_to_schema,
     schema_to_pyarrow,
     visit_pyarrow,
@@ -209,7 +210,7 @@ def test_pyarrow_timestamp_tz_invalid_tz() -> None:
 
 
 def test_pyarrow_string_to_iceberg() -> None:
-    pyarrow_type = pa.string()
+    pyarrow_type = pa.large_string()
     converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
     assert converted_iceberg_type == StringType()
     assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == 
pyarrow_type
@@ -543,3 +544,39 @@ def test_pyarrow_schema_to_schema_fresh_ids_nested_schema(
     pyarrow_schema_nested_without_ids: pa.Schema, 
iceberg_schema_nested_no_ids: Schema
 ) -> None:
     assert visit_pyarrow(pyarrow_schema_nested_without_ids, 
_ConvertToIcebergWithoutIDs()) == iceberg_schema_nested_no_ids
+
+
+def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids: 
pa.Schema) -> None:
+    expected_schema = pa.schema([
+        pa.field("foo", pa.large_string(), nullable=False),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+        pa.field("qux", pa.large_list(pa.large_string()), nullable=False),
+        pa.field(
+            "quux",
+            pa.map_(
+                pa.large_string(),
+                pa.map_(pa.large_string(), pa.int32()),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            "location",
+            pa.large_list(
+                pa.struct([
+                    pa.field("latitude", pa.float32(), nullable=False),
+                    pa.field("longitude", pa.float32(), nullable=False),
+                ]),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            "person",
+            pa.struct([
+                pa.field("name", pa.large_string(), nullable=True),
+                pa.field("age", pa.int32(), nullable=False),
+            ]),
+            nullable=True,
+        ),
+    ])
+    assert 
_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) == 
expected_schema
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 96109ce9..23b42ef4 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1610,7 +1610,7 @@ def test_arrow_schema() -> None:
     )
 
     expected_schema = pa.schema([
-        pa.field("foo", pa.string(), nullable=False),
+        pa.field("foo", pa.large_string(), nullable=False),
         pa.field("bar", pa.int32(), nullable=True),
         pa.field("baz", pa.bool_(), nullable=True),
     ])

Reply via email to