This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git
The following commit(s) were added to refs/heads/main by this push:
new 583bf96 Fix: BytesReader flush ArrowFormatWriter and pass schema (#9)
583bf96 is described below
commit 583bf96739fba9d4d1c60f7664e08394c3ca3764
Author: yuzelin <[email protected]>
AuthorDate: Tue Sep 3 16:51:32 2024 +0800
Fix: BytesReader flush ArrowFormatWriter and pass schema (#9)
---
java_based_implementation/api_impl.py | 27 +++++++---------------
.../java/org/apache/paimon/python/BytesReader.java | 10 +++++++-
.../tests/test_write_and_read.py | 7 +++---
3 files changed, 20 insertions(+), 24 deletions(-)
diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 6d53289..2605cc3 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -16,8 +16,6 @@
# limitations under the License.
################################################################################
-import itertools
-
from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
from paimon_python_api import (catalog, table, read_builder, table_scan,
split, table_read,
@@ -123,15 +121,15 @@ class TableRead(table_read.TableRead):
def create_reader(self, split: Split):
self._j_bytes_reader.setSplit(split.to_j_split())
- batch_iterator = self._batch_generator()
- # to init arrow schema
- try:
- first_batch = next(batch_iterator)
- except StopIteration:
- return self._empty_batch_reader()
+ # get schema
+ if self._arrow_schema is None:
+ schema_bytes = self._j_bytes_reader.serializeSchema()
+ schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
+ self._arrow_schema = schema_reader.schema
+ schema_reader.close()
- batches = itertools.chain((b for b in [first_batch]), batch_iterator)
- return RecordBatchReader.from_batches(self._arrow_schema, batches)
+ batch_iterator = self._batch_generator()
+ return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
def _batch_generator(self) -> Iterator[RecordBatch]:
while True:
@@ -140,17 +138,8 @@ class TableRead(table_read.TableRead):
break
else:
stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
- if self._arrow_schema is None:
- self._arrow_schema = stream_reader.schema
yield from stream_reader
- def _empty_batch_reader(self):
- import pyarrow as pa
- schema = pa.schema([])
- empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
- empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
- return empty_reader
-
class BatchWriteBuilder(write_builder.BatchWriteBuilder):
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
index 45be1d5..9272f98 100644
--- a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
@@ -47,7 +47,7 @@ public class BytesReader {
public BytesReader(TableRead tableRead, RowType rowType) {
this.tableRead = tableRead;
- this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
+ this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
}
public void setSplit(Split split) throws IOException {
@@ -56,6 +56,13 @@ public class BytesReader {
nextRow();
}
+ public byte[] serializeSchema() {
+ VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ ArrowUtils.serializeToIpc(vsr, out);
+ return out.toByteArray();
+ }
+
@Nullable
public byte[] next() throws Exception {
if (nextRow == null) {
@@ -68,6 +75,7 @@ public class BytesReader {
rowCount++;
}
+ arrowFormatWriter.flush();
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
vsr.setRowCount(rowCount);
ByteArrayOutputStream out = new ByteArrayOutputStream();
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index 907741d..02ead23 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -92,8 +92,7 @@ class TableWriteReadTest(unittest.TestCase):
for split in splits
for batch in table_read.create_reader(split)
]
- result = pd.concat(data_frames)
- self.assertEqual(result.shape, (0, 0))
+ self.assertEqual(len(data_frames), 0)
def testWriteReadAppendTable(self):
create_simple_table(self.warehouse, 'default', 'simple_append_table',
False)
@@ -135,8 +134,8 @@ class TableWriteReadTest(unittest.TestCase):
]
result = pd.concat(data_frames)
- # check data
- pd.testing.assert_frame_equal(result, df)
+ # check data (ignore index)
+ pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
def testWriteWrongSchema(self):
create_simple_table(self.warehouse, 'default', 'test_wrong_schema',
False)