This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 02e6430  `create_table` with a PyArrow Schema (#305)
02e6430 is described below

commit 02e64300aee376a76c175be253a29dcd7c31f0cc
Author: Sung Yun <107272191+syu...@users.noreply.github.com>
AuthorDate: Mon Jan 29 21:10:39 2024 -0500

    `create_table` with a PyArrow Schema (#305)
---
 mkdocs/docs/api.md               |  19 +++++
 pyiceberg/catalog/__init__.py    |  22 ++++-
 pyiceberg/catalog/dynamodb.py    |   8 +-
 pyiceberg/catalog/glue.py        |   8 +-
 pyiceberg/catalog/hive.py        |   9 +-
 pyiceberg/catalog/noop.py        |   6 +-
 pyiceberg/catalog/rest.py        |   8 +-
 pyiceberg/catalog/sql.py         |   8 +-
 pyiceberg/io/pyarrow.py          |  25 ++++--
 pyiceberg/schema.py              |  53 ++++++------
 pyproject.toml                   |   2 +-
 tests/catalog/test_base.py       |  34 +++++++-
 tests/catalog/test_dynamodb.py   |  18 ++++
 tests/catalog/test_glue.py       |  23 ++++++
 tests/catalog/test_sql.py        |  20 +++++
 tests/conftest.py                | 173 +++++++++++++++++++++++++++++++++++++++
 tests/io/test_pyarrow_visitor.py | 120 ++++-----------------------
 17 files changed, 417 insertions(+), 139 deletions(-)

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 6f79835..650d391 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -146,6 +146,25 @@ catalog.create_table(
 )
 ```
 
+To create a table using a pyarrow schema:
+
+```python
+import pyarrow as pa
+
+schema = pa.schema(
+    [
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ]
+)
+
+catalog.create_table(
+    identifier="docs_example.bids",
+    schema=schema,
+)
+```
+
 ## Load a table
 
 ### Catalog table
diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index a39d0e9..6e5dc27 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -24,6 +24,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Dict,
     List,
@@ -56,6 +57,9 @@ from pyiceberg.typedef import (
 )
 from pyiceberg.utils.config import Config, merge_config
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 logger = logging.getLogger(__name__)
 
 _ENV_CONFIG = Config()
@@ -288,7 +292,7 @@ class Catalog(ABC):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -512,6 +516,22 @@ class Catalog(ABC):
             if overlap:
                 raise ValueError(f"Updates and deletes have an overlap: {overlap}")
 
+    @staticmethod
+    def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
+        if isinstance(schema, Schema):
+            return schema
+        try:
+            import pyarrow as pa
+
+            from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
+
+            if isinstance(schema, pa.Schema):
+                schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())  # type: ignore
+                return schema
+        except ModuleNotFoundError:
+            pass
+        raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema")
+
     def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
         if not location:
             return self._get_default_warehouse_location(database_name, table_name)
diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py
index 6c3f931..d5f3b5e 100644
--- a/pyiceberg/catalog/dynamodb.py
+++ b/pyiceberg/catalog/dynamodb.py
@@ -17,6 +17,7 @@
 import uuid
 from time import time
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -57,6 +58,9 @@ from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 DYNAMODB_CLIENT = "dynamodb"
 
 DYNAMODB_COL_IDENTIFIER = "identifier"
@@ -127,7 +131,7 @@ class DynamoDbCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -152,6 +156,8 @@ class DynamoDbCatalog(Catalog):
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
 
         location = self._resolve_table_location(location, database_name, table_name)
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index 645568f..8f860fa 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -17,6 +17,7 @@
 
 
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -88,6 +89,9 @@ from pyiceberg.types import (
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 # If Glue should skip archiving an old table version when creating a new version in a commit. By
 # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
 # max number of archived table versions (can be increased). So for streaming use case with lots
@@ -329,7 +333,7 @@ class GlueCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -354,6 +358,8 @@ class GlueCatalog(Catalog):
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
 
         location = self._resolve_table_location(location, database_name, table_name)
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py
index 331b9ca..8069321 100644
--- a/pyiceberg/catalog/hive.py
+++ b/pyiceberg/catalog/hive.py
@@ -18,6 +18,7 @@ import getpass
 import time
 from types import TracebackType
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -91,6 +92,10 @@ from pyiceberg.types import (
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
 # Replace by visitor
 hive_types = {
     BooleanType: "boolean",
@@ -250,7 +255,7 @@ class HiveCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -273,6 +278,8 @@ class HiveCatalog(Catalog):
             AlreadyExistsError: If a table with the name already exists.
             ValueError: If the identifier is invalid.
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         properties = {**DEFAULT_PROPERTIES, **properties}
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         current_time_millis = int(time.time() * 1000)
diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py
index 083f851..a8b7154 100644
--- a/pyiceberg/catalog/noop.py
+++ b/pyiceberg/catalog/noop.py
@@ -15,6 +15,7 @@
 #  specific language governing permissions and limitations
 #  under the License.
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -33,12 +34,15 @@ from pyiceberg.table import (
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class NoopCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index de192a9..34d75b5 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -16,6 +16,7 @@
 #  under the License.
 from json import JSONDecodeError
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -68,6 +69,9 @@ from pyiceberg.table import (
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 ICEBERG_REST_SPEC_VERSION = "0.14.1"
 
 
@@ -437,12 +441,14 @@ class RestCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
         properties: Properties = EMPTY_DICT,
     ) -> Table:
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         namespace_and_table = self._split_identifier_for_path(identifier)
         request = CreateTableRequest(
             name=namespace_and_table["table"],
diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 593c6b5..8a02b20 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -16,6 +16,7 @@
 # under the License.
 
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -65,6 +66,9 @@ from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase):
     pass
@@ -140,7 +144,7 @@ class SqlCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -165,6 +169,8 @@ class SqlCatalog(Catalog):
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         if not self._namespace_exists(database_name):
             raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}")
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 1d7dcbe..7a94ce4 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -26,6 +26,7 @@ with the pyarrow library.
 from __future__ import annotations
 
 import concurrent.futures
+import itertools
 import logging
 import os
 import re
@@ -34,7 +35,6 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from enum import Enum
 from functools import lru_cache, singledispatch
-from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -637,7 +637,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
-        all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes]))
+        all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
     return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
 
 
@@ -912,6 +912,21 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
         self._field_names.pop()
 
 
+class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
+    """
+    Converts PyArrowSchema to Iceberg Schema with all -1 ids.
+
+    The schema generated through this visitor should always be
+    used in conjunction with `new_table_metadata` function to
+    assign new field ids in order. This is currently used only
+    when creating an Iceberg Schema from a PyArrow schema when
+    creating a new Iceberg table.
+    """
+
+    def _field_id(self, field: pa.Field) -> int:
+        return -1
+
+
 def _task_to_table(
     fs: FileSystem,
     task: FileScanTask,
@@ -999,7 +1014,7 @@ def _task_to_table(
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
     deletes_per_file: Dict[str, List[ChunkedArray]] = {}
-    unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks]))
+    unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
     if len(unique_deletes) > 0:
         executor = ExecutorFactory.get_or_create()
         deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
@@ -1421,7 +1436,7 @@ class PyArrowStatisticsCollector(PreOrderSchemaVisitor[List[StatisticsCollector]
     def struct(
         self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
     ) -> List[StatisticsCollector]:
-        return list(chain(*[result() for result in field_results]))
+        return list(itertools.chain(*[result() for result in field_results]))
 
     def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
         self._field_id = field.field_id
@@ -1513,7 +1528,7 @@ class ID2ParquetPathVisitor(PreOrderSchemaVisitor[List[ID2ParquetPath]]):
         return struct_result()
 
     def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
-        return list(chain(*[result() for result in field_results]))
+        return list(itertools.chain(*[result() for result in field_results]))
 
     def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
         self._field_id = field.field_id
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index b61e467..6dd174f 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1221,50 +1221,57 @@ def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id:
 class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
     """Traverses the schema and assigns monotonically increasing ids."""
 
-    reserved_ids: Dict[int, int]
+    old_id_to_new_id: Dict[int, int]
 
     def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
-        self.reserved_ids = {}
+        self.old_id_to_new_id = {}
         counter = itertools.count(1)
         self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)
 
-    def _get_and_increment(self) -> int:
-        return self.next_id_func()
+    def _get_and_increment(self, current_id: int) -> int:
+        new_id = self.next_id_func()
+        self.old_id_to_new_id[current_id] = new_id
+        return new_id
 
     def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema:
-        # First we keep the original identifier_field_ids here, we remap afterwards
-        fields = struct_result().fields
-        return Schema(*fields, identifier_field_ids=[self.reserved_ids[field_id] for field_id in schema.identifier_field_ids])
+        return Schema(
+            *struct_result().fields,
+            identifier_field_ids=[self.old_id_to_new_id[field_id] for field_id in schema.identifier_field_ids],
+        )
 
     def struct(self, struct: StructType, field_results: List[Callable[[], IcebergType]]) -> StructType:
-        # assign IDs for this struct's fields first
-        self.reserved_ids.update({field.field_id: self._get_and_increment() for field in struct.fields})
-        return StructType(*[field() for field in field_results])
+        new_ids = [self._get_and_increment(field.field_id) for field in struct.fields]
+        new_fields = []
+        for field_id, field, field_type in zip(new_ids, struct.fields, field_results):
+            new_fields.append(
+                NestedField(
+                    field_id=field_id,
+                    name=field.name,
+                    field_type=field_type(),
+                    required=field.required,
+                    doc=field.doc,
+                )
+            )
+        return StructType(*new_fields)
 
     def field(self, field: NestedField, field_result: Callable[[], IcebergType]) -> IcebergType:
-        return NestedField(
-            field_id=self.reserved_ids[field.field_id],
-            name=field.name,
-            field_type=field_result(),
-            required=field.required,
-            doc=field.doc,
-        )
+        return field_result()
 
     def list(self, list_type: ListType, element_result: Callable[[], IcebergType]) -> ListType:
-        self.reserved_ids[list_type.element_id] = self._get_and_increment()
+        element_id = self._get_and_increment(list_type.element_id)
         return ListType(
-            element_id=self.reserved_ids[list_type.element_id],
+            element_id=element_id,
             element=element_result(),
             element_required=list_type.element_required,
         )
 
     def map(self, map_type: MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType]) -> MapType:
-        self.reserved_ids[map_type.key_id] = self._get_and_increment()
-        self.reserved_ids[map_type.value_id] = self._get_and_increment()
+        key_id = self._get_and_increment(map_type.key_id)
+        value_id = self._get_and_increment(map_type.value_id)
         return MapType(
-            key_id=self.reserved_ids[map_type.key_id],
+            key_id=key_id,
             key_type=key_result(),
-            value_id=self.reserved_ids[map_type.value_id],
+            value_id=value_id,
             value_type=value_result(),
             value_required=map_type.value_required,
         )
diff --git a/pyproject.toml b/pyproject.toml
index e7f18b5..d1bc82d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -311,7 +311,7 @@ select = [
     "I", # isort
     "UP", # pyupgrade
 ]
-ignore = ["E501","E203","B024","B028"]
+ignore = ["E501","E203","B024","B028","UP037"]
 
 # Allow autofix for all enabled rules (when `--fix`) is provided.
 fixable = ["ALL"]
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 911c06b..d15c90f 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -24,7 +24,9 @@ from typing import (
     Union,
 )
 
+import pyarrow as pa
 import pytest
+from pytest_lazyfixture import lazy_fixture
 
 from pyiceberg.catalog import (
     Catalog,
@@ -72,12 +74,14 @@ class InMemoryCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
         properties: Properties = EMPTY_DICT,
     ) -> Table:
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         identifier = Catalog.identifier_to_tuple(identifier)
         namespace = Catalog.namespace_from(identifier)
 
@@ -330,6 +334,34 @@ def test_create_table(catalog: InMemoryCatalog) -> None:
     assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table
 
 
+@pytest.mark.parametrize(
+    "schema,expected",
+    [
+        (lazy_fixture("pyarrow_schema_simple_without_ids"), lazy_fixture("iceberg_schema_simple_no_ids")),
+        (lazy_fixture("iceberg_schema_simple"), lazy_fixture("iceberg_schema_simple")),
+        (lazy_fixture("iceberg_schema_nested"), lazy_fixture("iceberg_schema_nested")),
+        (lazy_fixture("pyarrow_schema_nested_without_ids"), lazy_fixture("iceberg_schema_nested_no_ids")),
+    ],
+)
+def test_convert_schema_if_needed(
+    schema: Union[Schema, pa.Schema],
+    expected: Schema,
+    catalog: InMemoryCatalog,
+) -> None:
+    assert expected == catalog._convert_schema_if_needed(schema)
+
+
+def test_create_table_pyarrow_schema(catalog: InMemoryCatalog, pyarrow_schema_simple_without_ids: pa.Schema) -> None:
+    table = catalog.create_table(
+        identifier=TEST_TABLE_IDENTIFIER,
+        schema=pyarrow_schema_simple_without_ids,
+        location=TEST_TABLE_LOCATION,
+        partition_spec=TEST_TABLE_PARTITION_SPEC,
+        properties=TEST_TABLE_PROPERTIES,
+    )
+    assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table
+
+
 def test_create_table_raises_error_when_table_already_exists(catalog: InMemoryCatalog) -> None:
     # Given
     given_catalog_has_a_table(catalog)
diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py
index 5af89ef..bc80146 100644
--- a/tests/catalog/test_dynamodb.py
+++ b/tests/catalog/test_dynamodb.py
@@ -18,6 +18,7 @@ from typing import List
 from unittest import mock
 
 import boto3
+import pyarrow as pa
 import pytest
 from moto import mock_dynamodb
 
@@ -71,6 +72,23 @@ def test_create_table_with_database_location(
     assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
 
 
+@mock_dynamodb
+def test_create_table_with_pyarrow_schema(
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    pyarrow_schema_simple_without_ids: pa.Schema,
+    database_name: str,
+    table_name: str,
+) -> None:
+    catalog_name = "test_ddb_catalog"
+    identifier = (database_name, table_name)
+    test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"})
+    table = test_catalog.create_table(identifier, pyarrow_schema_simple_without_ids)
+    assert table.identifier == (catalog_name,) + identifier
+    assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+
+
 @mock_dynamodb
 def test_create_table_with_default_warehouse(
     _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index b1f1371..63a213f 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -18,6 +18,7 @@ from typing import Any, Dict, List
 from unittest import mock
 
 import boto3
+import pyarrow as pa
 import pytest
 from moto import mock_glue
 
@@ -101,6 +102,28 @@ def test_create_table_with_given_location(
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
 
+@mock_glue
+def test_create_table_with_pyarrow_schema(
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    pyarrow_schema_simple_without_ids: pa.Schema,
+    database_name: str,
+    table_name: str,
+) -> None:
+    catalog_name = "glue"
+    identifier = (database_name, table_name)
+    test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name)
+    table = test_catalog.create_table(
+        identifier=identifier,
+        schema=pyarrow_schema_simple_without_ids,
+        location=f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}",
+    )
+    assert table.identifier == (catalog_name,) + identifier
+    assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+    assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+
+
 @mock_glue
 def test_create_table_with_no_location(
     _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 9dbcf8f..1ca8fd1 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -158,6 +158,26 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_create_table_with_pyarrow_schema(
+    catalog: SqlCatalog,
+    pyarrow_schema_simple_without_ids: pa.Schema,
+    iceberg_table_schema_simple: Schema,
+    random_identifier: Identifier,
+) -> None:
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_schema_simple_without_ids)
+    assert table.schema() == iceberg_table_schema_simple
+    catalog.drop_table(random_identifier)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
diff --git a/tests/conftest.py b/tests/conftest.py
index 9c53301..d9a8dfd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -45,6 +45,7 @@ from typing import (
 from urllib.parse import urlparse
 
 import boto3
+import pyarrow as pa
 import pytest
 from moto import mock_dynamodb, mock_glue
 from moto.server import ThreadedMotoServer  # type: ignore
@@ -267,6 +268,178 @@ def table_schema_nested_with_struct_key_map() -> Schema:
     )
 
 
+@pytest.fixture(scope="session")
+def pyarrow_schema_simple_without_ids() -> pa.Schema:
+    return pa.schema([
+        pa.field('foo', pa.string(), nullable=True),
+        pa.field('bar', pa.int32(), nullable=False),
+        pa.field('baz', pa.bool_(), nullable=True),
+    ])
+
+
+@pytest.fixture(scope="session")
+def pyarrow_schema_nested_without_ids() -> pa.Schema:
+    return pa.schema([
+        pa.field('foo', pa.string(), nullable=False),
+        pa.field('bar', pa.int32(), nullable=False),
+        pa.field('baz', pa.bool_(), nullable=True),
+        pa.field('qux', pa.list_(pa.string()), nullable=False),
+        pa.field(
+            'quux',
+            pa.map_(
+                pa.string(),
+                pa.map_(pa.string(), pa.int32()),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            'location',
+            pa.list_(
+                pa.struct([
+                    pa.field('latitude', pa.float32(), nullable=False),
+                    pa.field('longitude', pa.float32(), nullable=False),
+                ]),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            'person',
+            pa.struct([
+                pa.field('name', pa.string(), nullable=True),
+                pa.field('age', pa.int32(), nullable=False),
+            ]),
+            nullable=True,
+        ),
+    ])
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_simple() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_simple_no_ids() -> Schema:
+    return Schema(
+        NestedField(field_id=-1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_table_schema_simple() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        schema_id=0,
+        identifier_field_ids=[],
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_nested() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=4,
+            name="qux",
+            field_type=ListType(element_id=5, element_type=StringType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=6,
+            name="quux",
+            field_type=MapType(
+                key_id=7,
+                key_type=StringType(),
+                value_id=8,
+                value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=11,
+            name="location",
+            field_type=ListType(
+                element_id=12,
+                element_type=StructType(
+                    NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True),
+                    NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=16, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_nested_no_ids() -> Schema:
+    return Schema(
+        NestedField(field_id=-1, name="foo", field_type=StringType(), required=True),
+        NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=-1,
+            name="qux",
+            field_type=ListType(element_id=-1, element_type=StringType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="quux",
+            field_type=MapType(
+                key_id=-1,
+                key_type=StringType(),
+                value_id=-1,
+                value_type=MapType(key_id=-1, key_type=StringType(), value_id=-1, value_type=IntegerType(), value_required=False),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="location",
+            field_type=ListType(
+                element_id=-1,
+                element_type=StructType(
+                    NestedField(field_id=-1, name="latitude", field_type=FloatType(), required=True),
+                    NestedField(field_id=-1, name="longitude", field_type=FloatType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=-1, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=-1, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+
+
 @pytest.fixture(scope="session")
 def all_avro_types() -> Dict[str, Any]:
     return {
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index 0986eac..c7f364b 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -23,6 +23,7 @@ import pytest
 from pyiceberg.io.pyarrow import (
     _ConvertToArrowSchema,
     _ConvertToIceberg,
+    _ConvertToIcebergWithoutIDs,
     _HasIds,
     pyarrow_to_schema,
     schema_to_pyarrow,
@@ -51,104 +52,6 @@ from pyiceberg.types import (
 )
 
 
-@pytest.fixture(scope="module")
-def pyarrow_schema_simple_without_ids() -> pa.Schema:
-    return pa.schema([pa.field('some_int', pa.int32(), nullable=True), pa.field('some_string', pa.string(), nullable=False)])
-
-
-@pytest.fixture(scope="module")
-def pyarrow_schema_nested_without_ids() -> pa.Schema:
-    return pa.schema([
-        pa.field('foo', pa.string(), nullable=False),
-        pa.field('bar', pa.int32(), nullable=False),
-        pa.field('baz', pa.bool_(), nullable=True),
-        pa.field('qux', pa.list_(pa.string()), nullable=False),
-        pa.field(
-            'quux',
-            pa.map_(
-                pa.string(),
-                pa.map_(pa.string(), pa.int32()),
-            ),
-            nullable=False,
-        ),
-        pa.field(
-            'location',
-            pa.list_(
-                pa.struct([
-                    pa.field('latitude', pa.float32(), nullable=False),
-                    pa.field('longitude', pa.float32(), nullable=False),
-                ]),
-            ),
-            nullable=False,
-        ),
-        pa.field(
-            'person',
-            pa.struct([
-                pa.field('name', pa.string(), nullable=True),
-                pa.field('age', pa.int32(), nullable=False),
-            ]),
-            nullable=True,
-        ),
-    ])
-
-
-@pytest.fixture(scope="module")
-def iceberg_schema_simple() -> Schema:
-    return Schema(
-        NestedField(field_id=1, name="some_int", field_type=IntegerType(), required=False),
-        NestedField(field_id=2, name="some_string", field_type=StringType(), required=True),
-    )
-
-
-@pytest.fixture(scope="module")
-def iceberg_schema_nested() -> Schema:
-    return Schema(
-        NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
-        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
-        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
-        NestedField(
-            field_id=4,
-            name="qux",
-            field_type=ListType(element_id=5, element_type=StringType(), element_required=False),
-            required=True,
-        ),
-        NestedField(
-            field_id=6,
-            name="quux",
-            field_type=MapType(
-                key_id=7,
-                key_type=StringType(),
-                value_id=8,
-                value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False),
-                value_required=False,
-            ),
-            required=True,
-        ),
-        NestedField(
-            field_id=11,
-            name="location",
-            field_type=ListType(
-                element_id=12,
-                element_type=StructType(
-                    NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True),
-                    NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True),
-                ),
-                element_required=False,
-            ),
-            required=True,
-        ),
-        NestedField(
-            field_id=15,
-            name="person",
-            field_type=StructType(
-                NestedField(field_id=16, name="name", field_type=StringType(), required=False),
-                NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
-            ),
-            required=False,
-        ),
-    )
-
-
 def test_pyarrow_binary_to_iceberg() -> None:
     length = 23
     pyarrow_type = pa.binary(length)
@@ -468,8 +371,9 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping(
 ) -> None:
     schema = pyarrow_schema_simple_without_ids
     name_mapping = NameMapping([
-        MappedField(field_id=1, names=['some_int']),
-        MappedField(field_id=2, names=['some_string']),
+        MappedField(field_id=1, names=['foo']),
+        MappedField(field_id=2, names=['bar']),
+        MappedField(field_id=3, names=['baz']),
     ])
 
     assert pyarrow_to_schema(schema, name_mapping) == iceberg_schema_simple
@@ -480,11 +384,11 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping_partial_
 ) -> None:
     schema = pyarrow_schema_simple_without_ids
     name_mapping = NameMapping([
-        MappedField(field_id=1, names=['some_string']),
+        MappedField(field_id=1, names=['foo']),
     ])
     with pytest.raises(ValueError) as exc_info:
         _ = pyarrow_to_schema(schema, name_mapping)
-    assert "Could not find field with name: some_int" in str(exc_info.value)
+    assert "Could not find field with name: bar" in str(exc_info.value)
 
 
 def test_nested_pyarrow_schema_to_schema_missing_ids_using_name_mapping(
@@ -572,3 +476,15 @@ def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_
     with pytest.raises(ValueError) as exc_info:
         _ = pyarrow_to_schema(schema, name_mapping)
     assert "Could not find field with name: quux.value.key" in str(exc_info.value)
+
+
+def test_pyarrow_schema_to_schema_fresh_ids_simple_schema(
+    pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple_no_ids: Schema
+) -> None:
+    assert visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_simple_no_ids
+
+
+def test_pyarrow_schema_to_schema_fresh_ids_nested_schema(
+    pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested_no_ids: Schema
+) -> None:
+    assert visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_nested_no_ids


Reply via email to