This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new acc934f Check the types when writing (#313)
acc934f is described below
commit acc934fb76aa6c6e2e32b60c8a99f9e2b2c627dd
Author: Fokko Driesprong <[email protected]>
AuthorDate: Mon Jan 29 00:21:15 2024 +0100
Check the types when writing (#313)
---
pyiceberg/table/__init__.py | 16 ++++++++++++++++
tests/integration/test_writes.py | 18 ++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 221a609..26eecef 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -932,6 +932,14 @@ class Table:
Args:
df: The Arrow dataframe that will be appended to overwrite the table
"""
+ try:
+ import pyarrow as pa
+ except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+ if not isinstance(df, pa.Table):
+ raise ValueError(f"Expected PyArrow table, got: {df}")
+
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
@@ -954,6 +962,14 @@ class Table:
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial
overwrite
"""
+ try:
+ import pyarrow as pa
+ except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+ if not isinstance(df, pa.Table):
+ raise ValueError(f"Expected PyArrow table, got: {df}")
+
if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index a095c13..17dc997 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -391,3 +391,21 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1]
assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]
+
+
[email protected]
+def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+ identifier = "default.arrow_data_files"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+
+    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+ tbl.overwrite("not a df")
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+ tbl.append("not a df")