This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new c505c28  Add table bucket mode check and data schema check for writing (#8)
c505c28 is described below

commit c505c283616c1b295eef46cefb19a365437d977e
Author: yuzelin <[email protected]>
AuthorDate: Tue Sep 3 16:00:54 2024 +0800

    Add table bucket mode check and data schema check for writing (#8)
---
 java_based_implementation/api_impl.py              |  3 +-
 .../java/org/apache/paimon/python/BytesWriter.java | 41 +++++++++++++++++++
 .../tests/test_write_and_read.py                   | 47 ++++++++++++++++++++++
 java_based_implementation/tests/utils.py           | 11 +++--
 java_based_implementation/util/java_utils.py       |  8 ++++
 5 files changed, 106 insertions(+), 4 deletions(-)

diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 609d181..6d53289 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -19,7 +19,7 @@
 import itertools
 
 from java_based_implementation.java_gateway import get_gateway
-from java_based_implementation.util.java_utils import to_j_catalog_context
+from java_based_implementation.util.java_utils import to_j_catalog_context, 
check_batch_write
 from paimon_python_api import (catalog, table, read_builder, table_scan, 
split, table_read,
                                write_builder, table_write, commit_message, 
table_commit)
 from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,
@@ -56,6 +56,7 @@ class Table(table.Table):
         return ReadBuilder(j_read_builder, self._j_table.rowType())
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
+        check_batch_write(self._j_table)
         j_batch_write_builder = 
get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
         return BatchWriteBuilder(j_batch_write_builder, 
self._j_table.rowType())
 
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index 1f4ea63..7ad74c1 100644
--- a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.python;
 
+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -27,8 +28,12 @@ import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;
 
 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
 
 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {
@@ -36,17 +41,31 @@ public class BytesWriter {
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;
 
     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }
 
     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, 
allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table 
schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }
+
         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
             for (InternalRow row : rows) {
@@ -59,4 +78,26 @@ public class BytesWriter {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkSchema(List<Field> expectedFields, List<Field> 
actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            if (!checkField(expectedField, actualField)
+                    || !checkSchema(expectedField.getChildren(), 
actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private boolean checkField(Field expected, Field actual) {
+        return Objects.equals(expected.getName(), actual.getName())
+                && Objects.equals(expected.getType(), actual.getType());
+    }
 }
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index 2833514..907741d 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -26,6 +26,7 @@ from java_based_implementation.api_impl import Catalog, Table
 from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.tests.utils import set_bridge_jar, 
create_simple_table
 from java_based_implementation.util import constants, java_utils
+from py4j.protocol import Py4JJavaError
 
 
 class TableWriteReadTest(unittest.TestCase):
@@ -136,3 +137,49 @@ class TableWriteReadTest(unittest.TestCase):
 
         # check data
         pd.testing.assert_frame_equal(result, df)
+
+    def testWriteWrongSchema(self):
+        create_simple_table(self.warehouse, 'default', 'test_wrong_schema', 
False)
+
+        catalog = Catalog.create({'warehouse': self.warehouse})
+        table = catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with 
table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testCannotWriteDynamicBucketTable(self):
+        create_simple_table(
+            self.warehouse,
+            'default',
+            'test_dynamic_bucket',
+            True,
+            {'bucket': '-1'}
+        )
+
+        catalog = Catalog.create({'warehouse': self.warehouse})
+        table = catalog.get_table('default.test_dynamic_bucket')
+
+        with self.assertRaises(TypeError) as e:
+            table.new_batch_write_builder()
+        self.assertEqual(
+            str(e.exception),
+            "Doesn't support writing dynamic bucket or cross partition table.")
diff --git a/java_based_implementation/tests/utils.py b/java_based_implementation/tests/utils.py
index 86276f2..f600376 100644
--- a/java_based_implementation/tests/utils.py
+++ b/java_based_implementation/tests/utils.py
@@ -45,7 +45,13 @@ def set_bridge_jar() -> str:
     return os.path.join(temp_dir, jar_name)
 
 
-def create_simple_table(warehouse, database, table_name, has_pk):
+def create_simple_table(warehouse, database, table_name, has_pk, options=None):
+    if options is None:
+        options = {
+            'bucket': '1',
+            'bucket-key': 'f0'
+        }
+
     gateway = get_gateway()
 
     j_catalog_context = to_j_catalog_context({'warehouse': warehouse})
@@ -55,8 +61,7 @@ def create_simple_table(warehouse, database, table_name, has_pk):
         gateway.jvm.Schema.newBuilder()
         .column('f0', gateway.jvm.DataTypes.INT())
         .column('f1', gateway.jvm.DataTypes.STRING())
-        .option('bucket', '1')
-        .option('bucket-key', 'f0')
+        .options(options)
     )
     if has_pk:
         j_schema_builder.primaryKey(['f0'])
diff --git a/java_based_implementation/util/java_utils.py b/java_based_implementation/util/java_utils.py
index 83a3194..b9f2523 100644
--- a/java_based_implementation/util/java_utils.py
+++ b/java_based_implementation/util/java_utils.py
@@ -23,3 +23,11 @@ def to_j_catalog_context(catalog_context: dict):
     gateway = get_gateway()
     j_options = gateway.jvm.Options(catalog_context)
     return gateway.jvm.CatalogContext.create(j_options)
+
+
+def check_batch_write(j_table):
+    gateway = get_gateway()
+    bucket_mode = j_table.bucketMode()
+    if bucket_mode == gateway.jvm.BucketMode.HASH_DYNAMIC \
+            or bucket_mode == gateway.jvm.BucketMode.CROSS_PARTITION:
+        raise TypeError("Doesn't support writing dynamic bucket or cross 
partition table.")

Reply via email to