This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new acc934f  Check the types when writing (#313)
acc934f is described below

commit acc934fb76aa6c6e2e32b60c8a99f9e2b2c627dd
Author: Fokko Driesprong <[email protected]>
AuthorDate: Mon Jan 29 00:21:15 2024 +0100

    Check the types when writing (#313)
---
 pyiceberg/table/__init__.py      | 16 ++++++++++++++++
 tests/integration/test_writes.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 221a609..26eecef 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -932,6 +932,14 @@ class Table:
         Args:
            df: The Arrow dataframe that will be appended to overwrite the table
         """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be 
installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
@@ -954,6 +962,14 @@ class Table:
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                              or a boolean expression in case of a partial overwrite
         """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
         if overwrite_filter != AlwaysTrue():
             raise NotImplementedError("Cannot overwrite a subset of a table")
 
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index a095c13..17dc997 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -391,3 +391,21 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
     assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1]
     assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
     assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]
+
+
[email protected]
+def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.arrow_data_files"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+        tbl.overwrite("not a df")
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+        tbl.append("not a df")

Reply via email to