This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git
The following commit(s) were added to refs/heads/main by this push:
new f09dc58 Refactor ReadBuilder#with_projection to accept field names for better using (#27)
f09dc58 is described below
commit f09dc5887fdcab90b7bcac361e54414fab089fc9
Author: yuzelin <[email protected]>
AuthorDate: Mon Nov 25 20:40:20 2024 +0800
    Refactor ReadBuilder#with_projection to accept field names for better using (#27)
---
paimon_python_api/read_builder.py | 2 +-
paimon_python_java/pypaimon.py | 43 ++++++++---------
paimon_python_java/tests/test_write_and_read.py | 62 +++++++++++++++++++++++++
paimon_python_java/util/java_utils.py | 9 ++++
4 files changed, 92 insertions(+), 24 deletions(-)
diff --git a/paimon_python_api/read_builder.py b/paimon_python_api/read_builder.py
index ad5e6d6..a031a05 100644
--- a/paimon_python_api/read_builder.py
+++ b/paimon_python_api/read_builder.py
@@ -32,7 +32,7 @@ class ReadBuilder(ABC):
"""
@abstractmethod
- def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
+ def with_projection(self, projection: List[str]) -> 'ReadBuilder':
"""Push nested projection."""
@abstractmethod
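In practice, the signature change means callers select columns by name instead of by nested index lists. A minimal usage sketch, mirroring the new test added further below (it assumes `table` was already fetched from a catalog, as in the tests):

    # Project by field name; the order of names controls output column order.
    read_builder = table.new_read_builder().with_projection(['f0', 'f2'])
    table_scan = read_builder.new_scan()
    table_read = read_builder.new_read()
    splits = table_scan.plan().splits()
    df = table_read.to_pandas(splits)  # resulting columns: f0, f2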
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 16c7a69..b884fa4 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -61,37 +61,36 @@ class Table(table.Table):
def __init__(self, j_table, catalog_options: dict):
self._j_table = j_table
self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()
def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
def new_batch_write_builder(self) -> 'BatchWriteBuilder':
java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)
class ReadBuilder(read_builder.ReadBuilder):
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
self._j_read_builder = j_read_builder
self._j_row_type = j_row_type
self._catalog_options = catalog_options
- self._arrow_schema = arrow_schema
def with_filter(self, predicate: 'Predicate'):
self._j_read_builder.withFilter(predicate.to_j_predicate())
return self
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self
def with_limit(self, limit: int) -> 'ReadBuilder':
@@ -104,7 +103,7 @@ class ReadBuilder(read_builder.ReadBuilder):
def new_read(self) -> 'TableRead':
j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
def new_predicate_builder(self) -> 'PredicateBuilder':
return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ class Split(split.Split):
class TableRead(table_read.TableRead):
-    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
self._j_table_read = j_table_read
- self._j_row_type = j_row_type
+ self._j_read_type = j_read_type
self._catalog_options = catalog_options
self._j_bytes_reader = None
- self._arrow_schema = arrow_schema
+ self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
def to_arrow(self, splits):
record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ class TableRead(table_read.TableRead):
if max_workers <= 0:
raise ValueError("max_workers must be greater than 0")
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-            self._j_table_read, self._j_row_type, max_workers)
+            self._j_table_read, self._j_read_type, max_workers)
def _batch_generator(self) -> Iterator[pa.RecordBatch]:
while True:
@@ -188,10 +187,8 @@ class TableRead(table_read.TableRead):
class BatchWriteBuilder(write_builder.BatchWriteBuilder):
-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_write_builder):
self._j_batch_write_builder = j_batch_write_builder
- self._j_row_type = j_row_type
- self._arrow_schema = arrow_schema
     def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
if static_partition is None:
@@ -201,7 +198,7 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
def new_write(self) -> 'BatchTableWrite':
j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, self._j_batch_write_builder.rowType())
def new_commit(self) -> 'BatchTableCommit':
j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
class BatchTableWrite(table_write.BatchTableWrite):
-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
- self._arrow_schema = arrow_schema
+ self._arrow_schema = java_utils.to_arrow_schema(j_row_type)
def write_arrow(self, table):
for record_batch in table.to_reader():
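The heart of the refactor is the name-to-index translation in with_projection: field names are resolved against the table's row type, and the resulting indices are copied into a Java int[] via py4j before the existing withProjection(int[]) is invoked. A plain-Python sketch of the translation step, with the list values as illustrative stand-ins for what j_row_type.getFields() returns:

    field_names = ['f0', 'f1', 'f2', 'f3']   # stand-in for the Java RowType field names
    projection = ['f3', 'f2']                # caller-supplied field names
    int_projection = [field_names.index(p) for p in projection]
    assert int_projection == [3, 2]
    # An unknown name fails fast on the Python side:
    # field_names.index('f9') raises ValueError before any Java call is made.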
diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py
index b468e9f..337b9f5 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -445,3 +445,65 @@ class TableWriteReadTest(unittest.TestCase):
df['f0'] = df['f0'].astype('int32')
pd.testing.assert_frame_equal(
actual_df.reset_index(drop=True), df.reset_index(drop=True))
+
+ def testProjection(self):
+ pa_schema = pa.schema([
+ ('f0', pa.int64()),
+ ('f1', pa.string()),
+ ('f2', pa.bool_()),
+ ('f3', pa.string())
+ ])
+ schema = Schema(pa_schema)
+ self.catalog.create_table('default.test_projection', schema, False)
+ table = self.catalog.get_table('default.test_projection')
+
+ # prepare data
+ data = {
+ 'f0': [1, 2, 3],
+ 'f1': ['a', 'b', 'c'],
+ 'f2': [True, True, False],
+ 'f3': ['A', 'B', 'C']
+ }
+ df = pd.DataFrame(data)
+
+ # write and commit data
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ table_write.write_pandas(df)
+ commit_messages = table_write.prepare_commit()
+ table_commit.commit(commit_messages)
+
+ table_write.close()
+ table_commit.close()
+
+ # case 1: read empty
+ read_builder = table.new_read_builder().with_projection([])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ result1 = table_read.to_pandas(splits)
+ self.assertTrue(result1.empty)
+
+ # case 2: read fully
+        read_builder = table.new_read_builder().with_projection(['f0', 'f1', 'f2', 'f3'])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ result2 = table_read.to_pandas(splits)
+ pd.testing.assert_frame_equal(
+ result2.reset_index(drop=True), df.reset_index(drop=True))
+
+ # case 3: read partially
+ read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ result3 = table_read.to_pandas(splits)
+ expected_df = pd.DataFrame({
+ 'f3': ['A', 'B', 'C'],
+ 'f2': [True, True, False]
+ })
+ pd.testing.assert_frame_equal(
+ result3.reset_index(drop=True), expected_df.reset_index(drop=True))
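Worth noting: case 3 doubles as a column-reordering check. Projecting ['f3', 'f2'] yields columns in exactly that order rather than in schema order, because the projection indices are built in the order the names are supplied.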
diff --git a/paimon_python_java/util/java_utils.py b/paimon_python_java/util/java_utils.py
index 8c4f276..ce0404a 100644
--- a/paimon_python_java/util/java_utils.py
+++ b/paimon_python_java/util/java_utils.py
@@ -91,3 +91,12 @@ def _to_j_type(name, pa_type):
return jvm.DataTypes.STRING()
else:
         raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')
+
+
+def to_arrow_schema(j_row_type):
+ # init arrow schema
+ schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
+ schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+ arrow_schema = schema_reader.schema
+ schema_reader.close()
+ return arrow_schema
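The new to_arrow_schema helper relies on a standard Arrow IPC property: a stream's schema can be recovered from its serialized bytes without reading any record batches. A self-contained sketch of that round-trip in pure pyarrow (here the bytes are produced in Python; in the patch they come from the Java-side SchemaUtil):

    import pyarrow as pa

    schema = pa.schema([('f0', pa.int64()), ('f1', pa.string())])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, schema):
        pass                               # writes the schema message only, no batches
    schema_bytes = sink.getvalue()

    reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
    arrow_schema = reader.schema           # schema is available before reading any batch
    reader.close()
    assert arrow_schema.equals(schema)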