This is an automated email from the ASF dual-hosted git repository.
yuzelin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git
The following commit(s) were added to refs/heads/main by this push:
new 73366fb Complete table read and write (#6)
73366fb is described below
commit 73366fb6f84fc7b661d4441f3f49ca4abc0ffe92
Author: yuzelin <[email protected]>
AuthorDate: Tue Aug 20 17:24:57 2024 +0800
Complete table read and write (#6)
---
dev/dev-requirements.txt | 4 +
java_based_implementation/api_impl.py | 84 ++++++++++---
java_based_implementation/java_gateway.py | 7 +-
.../paimon-python-java-bridge/pom.xml | 46 +++++++
.../java/org/apache/paimon/python/BytesReader.java | 90 ++++++++++++++
.../java/org/apache/paimon/python/BytesWriter.java | 62 +++++++++
.../org/apache/paimon/python/InvocationUtil.java | 49 ++++++++
java_based_implementation/tests/test_table_scan.py | 25 ----
.../tests/test_write_and_read.py | 138 +++++++++++++++++++++
java_based_implementation/tests/utils.py | 30 ++++-
paimon_python_api/table_commit.py | 4 +
paimon_python_api/table_write.py | 4 +
12 files changed, 499 insertions(+), 44 deletions(-)
diff --git a/dev/dev-requirements.txt b/dev/dev-requirements.txt
index c9cfdbc..7fd1aeb 100755
--- a/dev/dev-requirements.txt
+++ b/dev/dev-requirements.txt
@@ -21,4 +21,8 @@ setuptools>=18.0
wheel
py4j==0.10.9.7
pyarrow>=5.0.0
+pandas>=1.3.0
+numpy>=1.22.4
+python-dateutil>=2.8.0,<3
+pytz>=2018.3
pytest~=7.0
diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 8fcf473..609d181 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -16,12 +16,15 @@
# limitations under the License.
################################################################################
+import itertools
+
from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.util.java_utils import to_j_catalog_context
from paimon_python_api import (catalog, table, read_builder, table_scan,
split, table_read,
write_builder, table_write, commit_message,
table_commit)
-from pyarrow import RecordBatchReader, RecordBatch
-from typing import List
+from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,
+ RecordBatchStreamReader, BufferReader, RecordBatchReader)
+from typing import List, Iterator
class Catalog(catalog.Catalog):
@@ -49,18 +52,19 @@ class Table(table.Table):
self._j_table = j_table
def new_read_builder(self) -> 'ReadBuilder':
- j_read_builder = self._j_table.newReadBuilder()
- return ReadBuilder(j_read_builder)
+ j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
+ return ReadBuilder(j_read_builder, self._j_table.rowType())
def new_batch_write_builder(self) -> 'BatchWriteBuilder':
- j_batch_write_builder = self._j_table.newBatchWriteBuilder()
- return BatchWriteBuilder(j_batch_write_builder)
+ j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
+ return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
class ReadBuilder(read_builder.ReadBuilder):
- def __init__(self, j_read_builder):
+ def __init__(self, j_read_builder, j_row_type):
self._j_read_builder = j_read_builder
+ self._j_row_type = j_row_type
def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
self._j_read_builder.withProjection(projection)
@@ -75,8 +79,8 @@ class ReadBuilder(read_builder.ReadBuilder):
return TableScan(j_table_scan)
def new_read(self) -> 'TableRead':
- # TODO
- pass
+ j_table_read = self._j_read_builder.newRead()
+ return TableRead(j_table_read, self._j_row_type)
class TableScan(table_scan.TableScan):
@@ -110,15 +114,48 @@ class Split(split.Split):
class TableRead(table_read.TableRead):
- def create_reader(self, split: Split) -> RecordBatchReader:
- # TODO
- pass
+ def __init__(self, j_table_read, j_row_type):
+ self._j_table_read = j_table_read
+ self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createBytesReader(
+ j_table_read, j_row_type)
+ self._arrow_schema = None
+
+ def create_reader(self, split: Split):
+ self._j_bytes_reader.setSplit(split.to_j_split())
+ batch_iterator = self._batch_generator()
+ # to init arrow schema
+ try:
+ first_batch = next(batch_iterator)
+ except StopIteration:
+ return self._empty_batch_reader()
+
+ batches = itertools.chain((b for b in [first_batch]), batch_iterator)
+ return RecordBatchReader.from_batches(self._arrow_schema, batches)
+
+ def _batch_generator(self) -> Iterator[RecordBatch]:
+ while True:
+ next_bytes = self._j_bytes_reader.next()
+ if next_bytes is None:
+ break
+ else:
+ stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
+ if self._arrow_schema is None:
+ self._arrow_schema = stream_reader.schema
+ yield from stream_reader
+
+ def _empty_batch_reader(self):
+ import pyarrow as pa
+ schema = pa.schema([])
+ empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
+ empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
+ return empty_reader
class BatchWriteBuilder(write_builder.BatchWriteBuilder):
- def __init__(self, j_batch_write_builder):
+ def __init__(self, j_batch_write_builder, j_row_type):
self._j_batch_write_builder = j_batch_write_builder
+ self._j_row_type = j_row_type
def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
self._j_batch_write_builder.withOverwrite(static_partition)
@@ -126,7 +163,7 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
def new_write(self) -> 'BatchTableWrite':
j_batch_table_write = self._j_batch_write_builder.newWrite()
- return BatchTableWrite(j_batch_table_write)
+ return BatchTableWrite(j_batch_table_write, self._j_row_type)
def new_commit(self) -> 'BatchTableCommit':
j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -135,17 +172,27 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
class BatchTableWrite(table_write.BatchTableWrite):
- def __init__(self, j_batch_table_write):
+ def __init__(self, j_batch_table_write, j_row_type):
self._j_batch_table_write = j_batch_table_write
+ self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
+ j_batch_table_write, j_row_type)
def write(self, record_batch: RecordBatch):
- # TODO
- pass
+ stream = BufferOutputStream()
+ with RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+ writer.write(record_batch)
+ writer.close()
+ arrow_bytes = stream.getvalue().to_pybytes()
+ self._j_bytes_writer.write(arrow_bytes)
def prepare_commit(self) -> List['CommitMessage']:
j_commit_messages = self._j_batch_table_write.prepareCommit()
return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
+ def close(self):
+ self._j_batch_table_write.close()
+ self._j_bytes_writer.close()
+
class CommitMessage(commit_message.CommitMessage):
@@ -164,3 +211,6 @@ class BatchTableCommit(table_commit.BatchTableCommit):
def commit(self, commit_messages: List[CommitMessage]):
j_commit_messages = list(map(lambda cm: cm.to_j_commit_message(), commit_messages))
self._j_batch_table_commit.commit(j_commit_messages)
+
+ def close(self):
+ self._j_batch_table_commit.close()
diff --git a/java_based_implementation/java_gateway.py b/java_based_implementation/java_gateway.py
index 1618b9c..64f40b4 100644
--- a/java_based_implementation/java_gateway.py
+++ b/java_based_implementation/java_gateway.py
@@ -101,9 +101,14 @@ def launch_gateway():
return gateway
-# TODO: import more
def import_paimon_view(gateway):
java_import(gateway.jvm, "org.apache.paimon.table.*")
+ java_import(gateway.jvm, "org.apache.paimon.options.Options")
+ java_import(gateway.jvm, "org.apache.paimon.catalog.*")
+ java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
+ java_import(gateway.jvm, 'org.apache.paimon.types.*')
+ java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
+ java_import(gateway.jvm, "org.apache.paimon.data.*")
class Watchdog(object):
diff --git a/java_based_implementation/paimon-python-java-bridge/pom.xml b/java_based_implementation/paimon-python-java-bridge/pom.xml
index 9618c85..5eb26e1 100644
--- a/java_based_implementation/paimon-python-java-bridge/pom.xml
+++ b/java_based_implementation/paimon-python-java-bridge/pom.xml
@@ -33,8 +33,10 @@
<flink.shaded.hadoop.version>2.8.3-10.0</flink.shaded.hadoop.version>
<py4j.version>0.10.9.7</py4j.version>
<slf4j.version>1.7.32</slf4j.version>
+ <log4j.version>2.17.1</log4j.version>
<spotless.version>2.13.0</spotless.version>
<spotless.delimiter>package</spotless.delimiter>
+ <arrow.version>14.0.0</arrow.version>
</properties>
<dependencies>
@@ -47,18 +49,48 @@
<version>${paimon.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.paimon</groupId>
+ <artifactId>paimon-arrow</artifactId>
+ <version>${paimon.version}</version>
+ </dependency>
+
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.logging.log4j</groupId>
+ <artifactId>log4j-1.2-api</artifactId>
+ <version>${log4j.version}</version>
+ </dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-shaded-hadoop-2-uber</artifactId>
<version>${flink.shaded.hadoop.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <version>${arrow.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-unsafe</artifactId>
+ <version>${arrow.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-c-data</artifactId>
+ <version>${arrow.version}</version>
+ </dependency>
+
<!-- Python API dependencies -->
<dependency>
@@ -119,11 +151,25 @@
<artifactSet>
<includes combine.children="append">
<include>org.apache.paimon:paimon-bundle</include>
+ <include>org.apache.paimon:paimon-arrow</include>
+ <include>org.apache.arrow:arrow-vector</include>
+ <include>org.apache.arrow:arrow-memory-core</include>
+ <include>org.apache.arrow:arrow-memory-unsafe</include>
+ <include>org.apache.arrow:arrow-c-data</include>
+ <include>org.apache.arrow:arrow-format</include>
+ <include>com.google.flatbuffers:flatbuffers-java</include>
<include>org.slf4j:slf4j-api</include>
+ <include>org.apache.logging.log4j:log4j-1.2-api</include>
<include>org.apache.flink:flink-shaded-hadoop-2-uber</include>
<include>net.sf.py4j:py4j</include>
</includes>
</artifactSet>
+ <relocations>
+ <relocation>
+ <pattern>com.fasterxml.jackson</pattern>
+ <shadedPattern>org.apache.paimon.shade.jackson2.com.fasterxml.jackson</shadedPattern>
+ </relocation>
+ </relocations>
</configuration>
</execution>
</executions>
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
new file mode 100644
index 0000000..45be1d5
--- /dev/null
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.python;
+
+import org.apache.paimon.arrow.ArrowUtils;
+import org.apache.paimon.arrow.vector.ArrowFormatWriter;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.reader.RecordReaderIterator;
+import org.apache.paimon.table.source.Split;
+import org.apache.paimon.table.source.TableRead;
+import org.apache.paimon.types.RowType;
+
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import javax.annotation.Nullable;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
+/** Read Arrow bytes from split. */
+public class BytesReader {
+
+ private static final int DEFAULT_WRITE_BATCH_SIZE = 2048;
+
+ private final TableRead tableRead;
+ private final ArrowFormatWriter arrowFormatWriter;
+
+ private RecordReaderIterator<InternalRow> iterator;
+ private InternalRow nextRow;
+
+ public BytesReader(TableRead tableRead, RowType rowType) {
+ this.tableRead = tableRead;
+ this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
+ }
+
+ public void setSplit(Split split) throws IOException {
+ RecordReader<InternalRow> recordReader = tableRead.createReader(split);
+ iterator = new RecordReaderIterator<InternalRow>(recordReader);
+ nextRow();
+ }
+
+ @Nullable
+ public byte[] next() throws Exception {
+ if (nextRow == null) {
+ return null;
+ }
+
+ int rowCount = 0;
+ while (nextRow != null && arrowFormatWriter.write(nextRow)) {
+ nextRow();
+ rowCount++;
+ }
+
+ VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+ vsr.setRowCount(rowCount);
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ ArrowUtils.serializeToIpc(vsr, out);
+ if (nextRow == null) {
+ // close resource
+ arrowFormatWriter.close();
+ iterator.close();
+ }
+ return out.toByteArray();
+ }
+
+ private void nextRow() {
+ if (iterator.hasNext()) {
+ nextRow = iterator.next();
+ } else {
+ nextRow = null;
+ }
+ }
+}
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
new file mode 100644
index 0000000..1f4ea63
--- /dev/null
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.python;
+
+import org.apache.paimon.arrow.reader.ArrowBatchReader;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.table.sink.TableWrite;
+import org.apache.paimon.types.RowType;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+
+import java.io.ByteArrayInputStream;
+
+/** Write Arrow bytes to Paimon. */
+public class BytesWriter {
+
+ private final TableWrite tableWrite;
+ private final ArrowBatchReader arrowBatchReader;
+ private final BufferAllocator allocator;
+
+ public BytesWriter(TableWrite tableWrite, RowType rowType) {
+ this.tableWrite = tableWrite;
+ this.arrowBatchReader = new ArrowBatchReader(rowType);
+ this.allocator = new RootAllocator();
+ }
+
+ public void write(byte[] bytes) throws Exception {
+ ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+ ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
+ VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+ while (arrowStreamReader.loadNextBatch()) {
+ Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
+ for (InternalRow row : rows) {
+ tableWrite.write(row);
+ }
+ }
+ arrowStreamReader.close();
+ }
+
+ public void close() {
+ allocator.close();
+ }
+}
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/InvocationUtil.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/InvocationUtil.java
new file mode 100644
index 0000000..f29773a
--- /dev/null
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/InvocationUtil.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.python;
+
+import org.apache.paimon.table.Table;
+import org.apache.paimon.table.sink.BatchWriteBuilder;
+import org.apache.paimon.table.sink.TableWrite;
+import org.apache.paimon.table.source.ReadBuilder;
+import org.apache.paimon.table.source.TableRead;
+import org.apache.paimon.types.RowType;
+
+/**
+ * Call some methods in Python directly will raise py4j.Py4JException: Method method([]) does not
+ * exist. This util is a workaround.
+ */
+public class InvocationUtil {
+
+ public static BatchWriteBuilder getBatchWriteBuilder(Table table) {
+ return table.newBatchWriteBuilder();
+ }
+
+ public static ReadBuilder getReadBuilder(Table table) {
+ return table.newReadBuilder();
+ }
+
+ public static BytesReader createBytesReader(TableRead tableRead, RowType rowType) {
+ return new BytesReader(tableRead, rowType);
+ }
+
+ public static BytesWriter createBytesWriter(TableWrite tableWrite, RowType rowType) {
+ return new BytesWriter(tableWrite, rowType);
+ }
+}
diff --git a/java_based_implementation/tests/test_table_scan.py b/java_based_implementation/tests/test_table_scan.py
deleted file mode 100644
index 2e9a939..0000000
--- a/java_based_implementation/tests/test_table_scan.py
+++ /dev/null
@@ -1,25 +0,0 @@
-################################################################################
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-################################################################################
-
-import unittest
-
-
-class TableScanTest(unittest.TestCase):
-
- def test_splits_size(self):
- pass
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
new file mode 100644
index 0000000..2833514
--- /dev/null
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -0,0 +1,138 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import tempfile
+import unittest
+import pandas as pd
+import pyarrow as pa
+
+from java_based_implementation.api_impl import Catalog, Table
+from java_based_implementation.java_gateway import get_gateway
+from java_based_implementation.tests.utils import set_bridge_jar, create_simple_table
+from java_based_implementation.util import constants, java_utils
+
+
+class TableWriteReadTest(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ classpath = set_bridge_jar()
+ os.environ[constants.PYPAIMON_JAVA_CLASSPATH] = classpath
+ cls.warehouse = tempfile.mkdtemp()
+
+ def testReadEmptyAppendTable(self):
+ create_simple_table(self.warehouse, 'default', 'empty_append_table', False)
+ catalog = Catalog.create({'warehouse': self.warehouse})
+ table = catalog.get_table('default.empty_append_table')
+
+ # read data
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ splits = table_scan.plan().splits()
+
+ self.assertTrue(len(splits) == 0)
+
+ def testReadEmptyPkTable(self):
+ create_simple_table(self.warehouse, 'default', 'empty_pk_table', True)
+ gateway = get_gateway()
+ j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse})
+ j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
+ j_identifier = gateway.jvm.Identifier.fromString('default.empty_pk_table')
+ j_table = j_catalog.getTable(j_identifier)
+ j_write_builder = gateway.jvm.InvocationUtil.getBatchWriteBuilder(j_table)
+
+ # first commit
+ generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.INSERT, 2)
+ generic_row.setField(0, 1)
+ generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
+ table_write = j_write_builder.newWrite()
+ table_write.write(generic_row)
+ table_commit = j_write_builder.newCommit()
+ table_commit.commit(table_write.prepareCommit())
+ table_write.close()
+ table_commit.close()
+
+ # second commit
+ generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.DELETE, 2)
+ generic_row.setField(0, 1)
+ generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
+ table_write = j_write_builder.newWrite()
+ table_write.write(generic_row)
+ table_commit = j_write_builder.newCommit()
+ table_commit.commit(table_write.prepareCommit())
+ table_write.close()
+ table_commit.close()
+
+ # read data
+ table = Table(j_table)
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+
+ data_frames = [
+ batch.to_pandas()
+ for split in splits
+ for batch in table_read.create_reader(split)
+ ]
+ result = pd.concat(data_frames)
+ self.assertEqual(result.shape, (0, 0))
+
+ def testWriteReadAppendTable(self):
+ create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
+
+ catalog = Catalog.create({'warehouse': self.warehouse})
+ table = catalog.get_table('default.simple_append_table')
+
+ # prepare data
+ data = {
+ 'f0': [1, 2, 3],
+ 'f1': ['a', 'b', 'c'],
+ }
+ df = pd.DataFrame(data)
+ df['f0'] = df['f0'].astype('int32')
+ record_batch = pa.RecordBatch.from_pandas(df)
+
+ # write and commit data
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ table_write.write(record_batch)
+ commit_messages = table_write.prepare_commit()
+ table_commit.commit(commit_messages)
+
+ table_write.close()
+ table_commit.close()
+
+ # read data
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+
+ data_frames = [
+ batch.to_pandas()
+ for split in splits
+ for batch in table_read.create_reader(split)
+ ]
+ result = pd.concat(data_frames)
+
+ # check data
+ pd.testing.assert_frame_equal(result, df)
diff --git a/java_based_implementation/tests/utils.py b/java_based_implementation/tests/utils.py
index 05573f2..86276f2 100644
--- a/java_based_implementation/tests/utils.py
+++ b/java_based_implementation/tests/utils.py
@@ -21,9 +21,15 @@ import shutil
import subprocess
import tempfile
+from java_based_implementation.java_gateway import get_gateway
+from java_based_implementation.util.java_utils import to_j_catalog_context
+
def set_bridge_jar() -> str:
- java_module = '../paimon-python-java-bridge'
+ current_file_path = os.path.abspath(__file__)
+ current_dir = os.path.dirname(current_file_path)
+ parent_dir = os.path.dirname(current_dir)
+ java_module = os.path.join(parent_dir, 'paimon-python-java-bridge')
# build paimon-python-java-bridge
subprocess.run(
["mvn", "clean", "package"],
@@ -37,3 +43,25 @@ def set_bridge_jar() -> str:
temp_dir = tempfile.mkdtemp()
shutil.move(jar_file, temp_dir)
return os.path.join(temp_dir, jar_name)
+
+
+def create_simple_table(warehouse, database, table_name, has_pk):
+ gateway = get_gateway()
+
+ j_catalog_context = to_j_catalog_context({'warehouse': warehouse})
+ j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
+
+ j_schema_builder = (
+ gateway.jvm.Schema.newBuilder()
+ .column('f0', gateway.jvm.DataTypes.INT())
+ .column('f1', gateway.jvm.DataTypes.STRING())
+ .option('bucket', '1')
+ .option('bucket-key', 'f0')
+ )
+ if has_pk:
+ j_schema_builder.primaryKey(['f0'])
+ j_schema = j_schema_builder.build()
+
+ j_catalog.createDatabase(database, True)
+ j_identifier = gateway.jvm.Identifier(database, table_name)
+ j_catalog.createTable(j_identifier, j_schema, False)
diff --git a/paimon_python_api/table_commit.py b/paimon_python_api/table_commit.py
index 8040c45..517b66e 100644
--- a/paimon_python_api/table_commit.py
+++ b/paimon_python_api/table_commit.py
@@ -30,3 +30,7 @@ class BatchTableCommit(ABC):
Commit the commit messages to generate snapshots. One commit may generate
up to two snapshots, one for adding new files and the other for compaction.
"""
+
+ @abstractmethod
+ def close(self):
+ """Close this resource."""
diff --git a/paimon_python_api/table_write.py b/paimon_python_api/table_write.py
index a878c61..d0052a1 100644
--- a/paimon_python_api/table_write.py
+++ b/paimon_python_api/table_write.py
@@ -32,3 +32,7 @@ class BatchTableWrite(ABC):
@abstractmethod
def prepare_commit(self) -> List[CommitMessage]:
"""Prepare commit message for TableCommit. Collect incremental files for this writer."""
+
+ @abstractmethod
+ def close(self):
+ """Close this resource."""