This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 583bf96  Fix: BytesReader flush ArrowFormatWriter and pass schema (#9)
583bf96 is described below

commit 583bf96739fba9d4d1c60f7664e08394c3ca3764
Author: yuzelin <[email protected]>
AuthorDate: Tue Sep 3 16:51:32 2024 +0800

    Fix: BytesReader flush ArrowFormatWriter and pass schema (#9)
---
 java_based_implementation/api_impl.py              | 27 +++++++---------------
 .../java/org/apache/paimon/python/BytesReader.java | 10 +++++++-
 .../tests/test_write_and_read.py                   |  7 +++---
 3 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 6d53289..2605cc3 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -16,8 +16,6 @@
 # limitations under the License.
 
################################################################################
 
-import itertools
-
 from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.util.java_utils import to_j_catalog_context, 
check_batch_write
 from paimon_python_api import (catalog, table, read_builder, table_scan, 
split, table_read,
@@ -123,15 +121,15 @@ class TableRead(table_read.TableRead):
 
     def create_reader(self, split: Split):
         self._j_bytes_reader.setSplit(split.to_j_split())
-        batch_iterator = self._batch_generator()
-        # to init arrow schema
-        try:
-            first_batch = next(batch_iterator)
-        except StopIteration:
-            return self._empty_batch_reader()
+        # get schema
+        if self._arrow_schema is None:
+            schema_bytes = self._j_bytes_reader.serializeSchema()
+            schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
+            self._arrow_schema = schema_reader.schema
+            schema_reader.close()
 
-        batches = itertools.chain((b for b in [first_batch]), batch_iterator)
-        return RecordBatchReader.from_batches(self._arrow_schema, batches)
+        batch_iterator = self._batch_generator()
+        return RecordBatchReader.from_batches(self._arrow_schema, 
batch_iterator)
 
     def _batch_generator(self) -> Iterator[RecordBatch]:
         while True:
@@ -140,17 +138,8 @@ class TableRead(table_read.TableRead):
                 break
             else:
                 stream_reader = 
RecordBatchStreamReader(BufferReader(next_bytes))
-                if self._arrow_schema is None:
-                    self._arrow_schema = stream_reader.schema
                 yield from stream_reader
 
-    def _empty_batch_reader(self):
-        import pyarrow as pa
-        schema = pa.schema([])
-        empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
-        empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
-        return empty_reader
-
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
index 45be1d5..9272f98 100644
--- a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
@@ -47,7 +47,7 @@ public class BytesReader {
 
     public BytesReader(TableRead tableRead, RowType rowType) {
         this.tableRead = tableRead;
-        this.arrowFormatWriter = new ArrowFormatWriter(rowType, 
DEFAULT_WRITE_BATCH_SIZE);
+        this.arrowFormatWriter = new ArrowFormatWriter(rowType, 
DEFAULT_WRITE_BATCH_SIZE, true);
     }
 
     public void setSplit(Split split) throws IOException {
@@ -56,6 +56,13 @@ public class BytesReader {
         nextRow();
     }
 
+    public byte[] serializeSchema() {
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        ArrowUtils.serializeToIpc(vsr, out);
+        return out.toByteArray();
+    }
+
     @Nullable
     public byte[] next() throws Exception {
         if (nextRow == null) {
@@ -68,6 +75,7 @@ public class BytesReader {
             rowCount++;
         }
 
+        arrowFormatWriter.flush();
         VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
         vsr.setRowCount(rowCount);
         ByteArrayOutputStream out = new ByteArrayOutputStream();
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index 907741d..02ead23 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -92,8 +92,7 @@ class TableWriteReadTest(unittest.TestCase):
             for split in splits
             for batch in table_read.create_reader(split)
         ]
-        result = pd.concat(data_frames)
-        self.assertEqual(result.shape, (0, 0))
+        self.assertEqual(len(data_frames), 0)
 
     def testWriteReadAppendTable(self):
         create_simple_table(self.warehouse, 'default', 'simple_append_table', 
False)
@@ -135,8 +134,8 @@ class TableWriteReadTest(unittest.TestCase):
         ]
         result = pd.concat(data_frames)
 
-        # check data
-        pd.testing.assert_frame_equal(result, df)
+        # check data (ignore index)
+        pd.testing.assert_frame_equal(result.reset_index(drop=True), 
df.reset_index(drop=True))
 
     def testWriteWrongSchema(self):
         create_simple_table(self.warehouse, 'default', 'test_wrong_schema', 
False)

Reply via email to