This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 00af9ad  Fix that field nullability affects write (#24)
00af9ad is described below

commit 00af9ad1dae60963a0cfce1c19225b97c413faf6
Author: yuzelin <[email protected]>
AuthorDate: Tue Nov 12 13:53:06 2024 +0800

    Fix that field nullability affects write (#24)
---
 .../java/org/apache/paimon/python/BytesWriter.java | 37 +++++++++++
 paimon_python_java/pypaimon.py                     | 21 +++---
 paimon_python_java/tests/test_write_and_read.py    | 74 ++++++++++++++++++++++
 3 files changed, 121 insertions(+), 11 deletions(-)

diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
index 7cf6267..f2ca4e1 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.python;
 
+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -27,8 +28,11 @@ import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;
 
 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {
@@ -36,17 +40,30 @@ public class BytesWriter {
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;
 
     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }
 
     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (!checkTypesIgnoreNullability(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }
 
         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -60,4 +77,24 @@ public class BytesWriter {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkTypesIgnoreNullability(
+            List<Field> expectedFields, List<Field> actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            // ArrowType doesn't have nullability (similar to DataTypeRoot)
+            if (!actualField.getType().equals(expectedField.getType())
+                    || !checkTypesIgnoreNullability(
+                            expectedField.getChildren(), actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
 }
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 0d3101b..16c7a69 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -218,24 +218,23 @@ class BatchTableWrite(table_write.BatchTableWrite):
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
-            # TODO: can we use a reusable stream?
-            stream = pa.BufferOutputStream()
-            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-                writer.write(record_batch)
-            arrow_bytes = stream.getvalue().to_pybytes()
-            self._j_bytes_writer.write(arrow_bytes)
+            # TODO: can we use a reusable stream in #_write_arrow_batch ?
+            self._write_arrow_batch(record_batch)
 
     def write_arrow_batch(self, record_batch):
+        self._write_arrow_batch(record_batch)
+
+    def write_pandas(self, dataframe: pd.DataFrame):
+        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
+        self._write_arrow_batch(record_batch)
+
+    def _write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
+        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
             writer.write(record_batch)
         arrow_bytes = stream.getvalue().to_pybytes()
         self._j_bytes_writer.write(arrow_bytes)
 
-    def write_pandas(self, dataframe: pd.DataFrame):
-        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
-        self.write_arrow_batch(record_batch)
-
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
         return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py
index e1ea72b..b468e9f 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -22,6 +22,7 @@ import tempfile
 import unittest
 import pandas as pd
 import pyarrow as pa
+from py4j.protocol import Py4JJavaError
 
 from paimon_python_api import Schema
 from paimon_python_java import Catalog
@@ -371,3 +372,76 @@ class TableWriteReadTest(unittest.TestCase):
         df2['f0'] = df2['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df2.reset_index(drop=True), df2.reset_index(drop=True))
+
+    def testWriteWrongSchema(self):
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.test_wrong_schema', schema, False)
+        table = self.catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write_arrow_batch(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testIgnoreNullable(self):
+        pa_schema1 = pa.schema([
+            ('f0', pa.int32(), False),
+            ('f1', pa.string())
+        ])
+
+        pa_schema2 = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+
+        # write nullable to non-null
+        self._testIgnoreNullableImpl('test_ignore_nullable1', pa_schema1, pa_schema2)
+
+        # write non-null to nullable
+        self._testIgnoreNullableImpl('test_ignore_nullable2', pa_schema2, pa_schema1)
+
+    def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
+        schema = Schema(table_schema)
+        self.catalog.create_table(f'default.{table_name}', schema, False)
+        table = self.catalog.get_table(f'default.{table_name}')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        record_batch = pa.RecordBatch.from_pandas(pd.DataFrame(data), data_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        table_write.write_arrow_batch(record_batch)
+        table_commit.commit(table_write.prepare_commit())
+
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual_df = table_read.to_pandas(table_scan.plan().splits())
+        df['f0'] = df['f0'].astype('int32')
+        pd.testing.assert_frame_equal(
+            actual_df.reset_index(drop=True), df.reset_index(drop=True))

Reply via email to