This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new ae5635a943 [python] Fix read after data evolution updating by shard (#7157)
ae5635a943 is described below
commit ae5635a94375ab4416edc01fcd0ed529224a26fa
Author: XiaoHongbo <[email protected]>
AuthorDate: Sun Mar 1 21:52:50 2026 +0800
[python] Fix read after data evolution updating by shard (#7157)
---
.../pypaimon/read/reader/format_pyarrow_reader.py | 33 ++-
paimon-python/pypaimon/read/split_read.py | 7 +-
.../pypaimon/tests/data_evolution_test.py | 130 +++++++++-
.../pypaimon/tests/shard_table_updator_test.py | 275 ++++++++++++++++++++-
4 files changed, 429 insertions(+), 16 deletions(-)
diff --git a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
index dd5330227d..e9c9efd917 100644
--- a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
@@ -24,6 +24,8 @@ from pyarrow import RecordBatch
from pypaimon.common.file_io import FileIO
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
+from pypaimon.schema.data_types import DataField, PyarrowFieldParser
+from pypaimon.table.special_fields import SpecialFields
class FormatPyArrowReader(RecordBatchReader):
@@ -32,16 +34,18 @@ class FormatPyArrowReader(RecordBatchReader):
and filters it based on the provided predicate and projection.
"""
- def __init__(self, file_io: FileIO, file_format: str, file_path: str,
read_fields: List[str],
+ def __init__(self, file_io: FileIO, file_format: str, file_path: str,
+ read_fields: List[DataField],
push_down_predicate: Any, batch_size: int = 1024):
file_path_for_pyarrow = file_io.to_filesystem_path(file_path)
self.dataset = ds.dataset(file_path_for_pyarrow, format=file_format,
filesystem=file_io.filesystem)
self.read_fields = read_fields
+ self._read_field_names = [f.name for f in read_fields]
# Identify which fields exist in the file and which are missing
file_schema_names = set(self.dataset.schema.names)
- self.existing_fields = [field for field in read_fields if field in
file_schema_names]
- self.missing_fields = [field for field in read_fields if field not in
file_schema_names]
+ self.existing_fields = [f.name for f in read_fields if f.name in
file_schema_names]
+ self.missing_fields = [f.name for f in read_fields if f.name not in
file_schema_names]
# Only pass existing fields to PyArrow scanner to avoid errors
self.reader = self.dataset.scanner(
@@ -50,6 +54,10 @@ class FormatPyArrowReader(RecordBatchReader):
batch_size=batch_size
).to_reader()
+ self._output_schema = (
+ PyarrowFieldParser.from_paimon_schema(read_fields) if read_fields
else None
+ )
+
def read_arrow_batch(self) -> Optional[RecordBatch]:
try:
batch = self.reader.read_next_batch()
@@ -57,13 +65,22 @@ class FormatPyArrowReader(RecordBatchReader):
if not self.missing_fields:
return batch
- # Create columns for missing fields with null values
- missing_columns = [pa.nulls(batch.num_rows, type=pa.null()) for _
in self.missing_fields]
+ def _type_for_missing(name: str) -> pa.DataType:
+ if self._output_schema is not None:
+ idx = self._output_schema.get_field_index(name)
+ if idx >= 0:
+ return self._output_schema.field(idx).type
+ return pa.null()
+
+ missing_columns = [
+ pa.nulls(batch.num_rows, type=_type_for_missing(name))
+ for name in self.missing_fields
+ ]
# Reconstruct the batch with all fields in the correct order
all_columns = []
out_fields = []
- for field_name in self.read_fields:
+ for field_name in self._read_field_names:
if field_name in self.existing_fields:
# Get the column from the existing batch
column_idx = self.existing_fields.index(field_name)
@@ -72,8 +89,10 @@ class FormatPyArrowReader(RecordBatchReader):
else:
# Get the column from missing fields
column_idx = self.missing_fields.index(field_name)
+ col_type = _type_for_missing(field_name)
all_columns.append(missing_columns[column_idx])
- out_fields.append(pa.field(field_name, pa.null(),
nullable=True))
+ nullable = not SpecialFields.is_system_field(field_name)
+ out_fields.append(pa.field(field_name, col_type,
nullable=nullable))
# Create a new RecordBatch with all columns
return pa.RecordBatch.from_arrays(all_columns,
schema=pa.schema(out_fields))
diff --git a/paimon-python/pypaimon/read/split_read.py
b/paimon-python/pypaimon/read/split_read.py
index d76a71682b..6034723647 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -137,8 +137,11 @@ class SplitRead(ABC):
format_reader = FormatLanceReader(self.table.file_io, file_path,
read_file_fields,
read_arrow_predicate,
batch_size=batch_size)
elif file_format == CoreOptions.FILE_FORMAT_PARQUET or file_format ==
CoreOptions.FILE_FORMAT_ORC:
- format_reader = FormatPyArrowReader(self.table.file_io,
file_format, file_path,
- read_file_fields,
read_arrow_predicate, batch_size=batch_size)
+ name_to_field = {f.name: f for f in self.read_fields}
+ ordered_read_fields = [name_to_field[n] for n in read_file_fields
if n in name_to_field]
+ format_reader = FormatPyArrowReader(
+ self.table.file_io, file_format, file_path,
+ ordered_read_fields, read_arrow_predicate,
batch_size=batch_size)
else:
raise ValueError(f"Unexpected file format: {file_format}")
diff --git a/paimon-python/pypaimon/tests/data_evolution_test.py
b/paimon-python/pypaimon/tests/data_evolution_test.py
index 3759bfdb46..afb12cd948 100644
--- a/paimon-python/pypaimon/tests/data_evolution_test.py
+++ b/paimon-python/pypaimon/tests/data_evolution_test.py
@@ -28,7 +28,6 @@ import pyarrow.dataset as ds
from pypaimon import CatalogFactory, Schema
from pypaimon.common.predicate import Predicate
from pypaimon.manifest.manifest_list_manager import ManifestListManager
-from pypaimon.read.read_builder import ReadBuilder
from pypaimon.snapshot.snapshot_manager import SnapshotManager
from pypaimon.table.row.offset_row import OffsetRow
@@ -141,13 +140,63 @@ class DataEvolutionTest(unittest.TestCase):
('f1', pa.int16()),
]))
self.assertEqual(actual_data, expect_data)
+ self.assertEqual(
+ len(actual_data.schema), len(expect_data.schema),
+ 'Read output column count must match schema')
+ self.assertEqual(
+ actual_data.schema.names, expect_data.schema.names,
+ 'Read output column names must match schema')
+
+ def test_partitioned_read_requested_column_missing_in_file(self):
+ pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string()), ('dt',
pa.string())])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ partition_keys=['dt'],
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'}
+ )
+ self.catalog.create_table('default.test_partition_missing_col',
schema, False)
+ table = self.catalog.get_table('default.test_partition_missing_col')
+ wb = table.new_batch_write_builder()
- # assert manifest file meta contains min and max row id
+ tw1 = wb.new_write()
+ tc1 = wb.new_commit()
+ tw1.write_arrow(pa.Table.from_pydict(
+ {'f0': [1, 2], 'f1': ['a', 'b'], 'dt': ['p1', 'p1']},
+ schema=pa_schema
+ ))
+ tc1.commit(tw1.prepare_commit())
+ tw1.close()
+ tc1.close()
+
+ tw2 = wb.new_write().with_write_type(['f0', 'dt'])
+ tc2 = wb.new_commit()
+ # Row key extractor uses table column indices; pass table-ordered data
with null for f1
+ tw2.write_arrow(pa.Table.from_pydict(
+ {'f0': [3, 4], 'f1': [None, None], 'dt': ['p1', 'p1']},
+ schema=pa_schema
+ ))
+ tc2.commit(tw2.prepare_commit())
+ tw2.close()
+ tc2.close()
+
+ actual =
table.new_read_builder().new_read().to_arrow(table.new_read_builder().new_scan().plan().splits())
+ self.assertEqual(len(actual.schema), 3, 'Must have f0, f1, dt (no
silent drop when f1 missing in file)')
+ self.assertEqual(actual.schema.names, ['f0', 'f1', 'dt'])
+ self.assertEqual(actual.num_rows, 4)
+ f1_col = actual.column('f1')
+ self.assertEqual(f1_col[0].as_py(), 'a')
+ self.assertEqual(f1_col[1].as_py(), 'b')
+ self.assertIsNone(f1_col[2].as_py())
+ self.assertIsNone(f1_col[3].as_py())
+
+ # Assert manifest file meta contains min and max row id
manifest_list_manager = ManifestListManager(table)
snapshot_manager = SnapshotManager(table)
- manifest =
manifest_list_manager.read(snapshot_manager.get_latest_snapshot().delta_manifest_list)[0]
- self.assertEqual(0, manifest.min_row_id)
- self.assertEqual(1, manifest.max_row_id)
+ all_manifests =
manifest_list_manager.read_all(snapshot_manager.get_latest_snapshot())
+ first_commit = next((m for m in all_manifests if m.min_row_id == 0 and
m.max_row_id == 1), None)
+ self.assertIsNotNone(first_commit, "Should have a manifest with
min_row_id=0, max_row_id=1")
+ second_commit = next((m for m in all_manifests if m.min_row_id == 2
and m.max_row_id == 3), None)
+ self.assertIsNotNone(second_commit, "Should have a manifest with
min_row_id=2, max_row_id=3")
def test_merge_reader(self):
from pypaimon.read.reader.concat_batch_reader import
MergeAllBatchReader
@@ -280,6 +329,14 @@ class DataEvolutionTest(unittest.TestCase):
[2, 1001, 2001],
"with_slice(1, 4) should return id in (2, 1001, 2001). Got ids=%s"
% ids,
)
+ scan_oob = rb.new_scan().with_slice(10, 12)
+ splits_oob = scan_oob.plan().splits()
+ result_oob = rb.new_read().to_pandas(splits_oob)
+ self.assertEqual(
+ len(result_oob),
+ 0,
+ "with_slice(10, 12) on 6 rows should return 0 rows (out of
bounds), got %d" % len(result_oob),
+ )
# Out-of-bounds slice: 6 rows total, slice(10, 12) should return 0 rows
scan_oob = rb.new_scan().with_slice(10, 12)
@@ -439,6 +496,8 @@ class DataEvolutionTest(unittest.TestCase):
'f2': ['b'] * 100 + ['y'] + ['d'],
}, schema=simple_pa_schema)
self.assertEqual(actual, expect)
+ self.assertEqual(len(actual.schema), len(expect.schema), 'Merge read
output column count must match schema')
+ self.assertEqual(actual.schema.names, expect.schema.names, 'Merge read
output column names must match schema')
def test_disorder_cols_append(self):
simple_pa_schema = pa.schema([
@@ -1175,6 +1234,7 @@ class DataEvolutionTest(unittest.TestCase):
pa.field('_SEQUENCE_NUMBER', pa.int64(), nullable=False),
]))
self.assertEqual(actual_data, expect_data)
+ self.assertEqual(len(actual_data.schema), len(expect_data.schema),
'Read output column count must match schema')
# write 2
table_write = write_builder.new_write().with_write_type(['f0'])
@@ -1210,6 +1270,66 @@ class DataEvolutionTest(unittest.TestCase):
pa.field('_SEQUENCE_NUMBER', pa.int64(), nullable=False),
]))
self.assertEqual(actual_data, expect_data)
+ self.assertEqual(len(actual_data.schema), len(expect_data.schema),
'Read output column count must match schema')
+
+ def test_with_blob(self):
+ from pypaimon.table.row.blob import BlobDescriptor
+
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('picture', pa.large_binary()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ options={
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true',
+ 'blob-as-descriptor': 'true',
+ },
+ )
+ self.catalog.create_table('default.test_with_blob', schema, False)
+ table = self.catalog.get_table('default.test_with_blob')
+
+ blob_path = os.path.join(self.tempdir, 'blob_ev')
+ with open(blob_path, 'wb') as f:
+ f.write(b'x')
+ descriptor = BlobDescriptor(blob_path, 0, 1)
+
+ wb = table.new_batch_write_builder()
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(
+ {'id': [1], 'picture': [descriptor.serialize()]},
+ schema=pa_schema,
+ ))
+ cmts = tw.prepare_commit()
+ if cmts and cmts[0].new_files:
+ for nf in cmts[0].new_files:
+ nf.first_row_id = 0
+ tc.commit(cmts)
+ tw.close()
+ tc.close()
+
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(
+ {'id': [2], 'picture': [descriptor.serialize()]},
+ schema=pa_schema,
+ ))
+ cmts = tw.prepare_commit()
+ if cmts and cmts[0].new_files:
+ for nf in cmts[0].new_files:
+ nf.first_row_id = 1
+ tc.commit(cmts)
+ tw.close()
+ tc.close()
+
+ rb = table.new_read_builder()
+ rb.with_projection(['id', '_ROW_ID', 'picture', '_SEQUENCE_NUMBER'])
+ actual = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(actual.num_rows, 2)
+ self.assertEqual(actual.column('id').to_pylist(), [1, 2])
+ self.assertEqual(actual.column('_ROW_ID').to_pylist(), [0, 1])
def test_from_arrays_without_schema(self):
schema = pa.schema([
diff --git a/paimon-python/pypaimon/tests/shard_table_updator_test.py
b/paimon-python/pypaimon/tests/shard_table_updator_test.py
index 967dfbcd6e..1ff658c609 100644
--- a/paimon-python/pypaimon/tests/shard_table_updator_test.py
+++ b/paimon-python/pypaimon/tests/shard_table_updator_test.py
@@ -85,7 +85,7 @@ class ShardTableUpdatorTest(unittest.TestCase):
# Step 3: Use ShardTableUpdator to compute d = c + b - a
table_update = write_builder.new_update()
- table_update.with_read_projection(['a', 'b', 'c'])
+ table_update.with_read_projection(['a', 'b', 'c', '_ROW_ID'])
table_update.with_update_type(['d'])
shard_updator = table_update.new_shard_updator(0, 1)
@@ -98,7 +98,13 @@ class ShardTableUpdatorTest(unittest.TestCase):
a_values = batch.column('a').to_pylist()
b_values = batch.column('b').to_pylist()
c_values = batch.column('c').to_pylist()
-
+ row_id_values = batch.column('_ROW_ID').to_pylist()
+ self.assertEqual(
+ row_id_values,
+ list(range(len(a_values))),
+ '_ROW_ID should be [0, 1, 2, ...] for sequential rows',
+ )
+
d_values = [c + b - a for a, b, c in zip(a_values, b_values,
c_values)]
# Create batch with d column
@@ -321,5 +327,270 @@ class ShardTableUpdatorTest(unittest.TestCase):
self.assertEqual(actual, expected)
print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+ def test_partial_shard_update_full_read_schema_unified(self):
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ('d', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'},
+ )
+ name = self._create_unique_table_name()
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ # Two commits => two files (two first_row_id ranges)
+ for start, end in [(1, 10), (10, 20)]:
+ wb = table.new_batch_write_builder()
+ tw = wb.new_write().with_write_type(['a', 'b', 'c'])
+ tc = wb.new_commit()
+ data = pa.Table.from_pydict({
+ 'a': list(range(start, end + 1)),
+ 'b': [i * 10 for i in range(start, end + 1)],
+ 'c': [i * 100 for i in range(start, end + 1)],
+ }, schema=pa.schema([
+ ('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32()),
+ ]))
+ tw.write_arrow(data)
+ tc.commit(tw.prepare_commit())
+ tw.close()
+ tc.close()
+
+ # Only shard 0 runs => only first file gets d
+ wb = table.new_batch_write_builder()
+ upd = wb.new_update()
+ upd.with_read_projection(['a', 'b', 'c'])
+ upd.with_update_type(['d'])
+ shard0 = upd.new_shard_updator(0, 2)
+ reader = shard0.arrow_reader()
+ for batch in iter(reader.read_next_batch, None):
+ a_ = batch.column('a').to_pylist()
+ b_ = batch.column('b').to_pylist()
+ c_ = batch.column('c').to_pylist()
+ d_ = [c + b - a for a, b, c in zip(a_, b_, c_)]
+ shard0.update_by_arrow_batch(pa.RecordBatch.from_pydict(
+ {'d': d_}, schema=pa.schema([('d', pa.int32())]),
+ ))
+ tc = wb.new_commit()
+ tc.commit(shard0.prepare_commit())
+ tc.close()
+
+ rb = table.new_read_builder()
+ tr = rb.new_read()
+ actual = tr.to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(actual.num_rows, 21)
+ d_col = actual.column('d')
+ # First 10 rows (shard 0): d = c+b-a
+ for i in range(10):
+ self.assertEqual(d_col[i].as_py(), (i + 1) * 100 + (i + 1) * 10 -
(i + 1))
+ # Rows 10-20 (shard 1 not run): d is null
+ for i in range(10, 21):
+ self.assertIsNone(d_col[i].as_py())
+
+ def test_with_shard_read_after_partial_shard_update(self):
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ('d', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'},
+ )
+ name = self._create_unique_table_name()
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ for start, end in [(1, 10), (10, 20)]:
+ wb = table.new_batch_write_builder()
+ tw = wb.new_write().with_write_type(['a', 'b', 'c'])
+ tc = wb.new_commit()
+ data = pa.Table.from_pydict({
+ 'a': list(range(start, end + 1)),
+ 'b': [i * 10 for i in range(start, end + 1)],
+ 'c': [i * 100 for i in range(start, end + 1)],
+ }, schema=pa.schema([
+ ('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32()),
+ ]))
+ tw.write_arrow(data)
+ tc.commit(tw.prepare_commit())
+ tw.close()
+ tc.close()
+
+ wb = table.new_batch_write_builder()
+ upd = wb.new_update()
+ upd.with_read_projection(['a', 'b', 'c'])
+ upd.with_update_type(['d'])
+ shard0 = upd.new_shard_updator(0, 2)
+ reader = shard0.arrow_reader()
+ for batch in iter(reader.read_next_batch, None):
+ a_ = batch.column('a').to_pylist()
+ b_ = batch.column('b').to_pylist()
+ c_ = batch.column('c').to_pylist()
+ d_ = [c + b - a for a, b, c in zip(a_, b_, c_)]
+ shard0.update_by_arrow_batch(pa.RecordBatch.from_pydict(
+ {'d': d_}, schema=pa.schema([('d', pa.int32())]),
+ ))
+ tc = wb.new_commit()
+ tc.commit(shard0.prepare_commit())
+ tc.close()
+
+ rb = table.new_read_builder()
+ tr = rb.new_read()
+
+ splits_0 = rb.new_scan().with_shard(0, 2).plan().splits()
+ result_0 = tr.to_arrow(splits_0)
+ self.assertEqual(result_0.num_rows, 11)
+ d_col_0 = result_0.column('d')
+ for i in range(10):
+ self.assertEqual(
+ d_col_0[i].as_py(),
+ (i + 1) * 100 + (i + 1) * 10 - (i + 1),
+ "Shard 0 row %d: d should be c+b-a" % i,
+ )
+ self.assertIsNone(d_col_0[10].as_py(), "Shard 0 row 10: d not updated,
should be null")
+
+ splits_1 = rb.new_scan().with_shard(1, 2).plan().splits()
+ result_1 = tr.to_arrow(splits_1)
+ self.assertEqual(result_1.num_rows, 10)
+ d_col_1 = result_1.column('d')
+ for i in range(10):
+ self.assertIsNone(d_col_1[i].as_py(), "Shard 1 row %d: d should be
null" % i)
+
+ full_splits = rb.new_scan().plan().splits()
+ full_result = tr.to_arrow(full_splits)
+ self.assertEqual(
+ result_0.num_rows + result_1.num_rows,
+ full_result.num_rows,
+ "Shard 0 + Shard 1 row count should equal full scan (21)",
+ )
+
+ rb_filter = table.new_read_builder()
+ rb_filter.with_projection(['a', 'b', 'c', 'd', '_ROW_ID'])
+ pb = rb_filter.new_predicate_builder()
+ pred_row_id = pb.is_in('_ROW_ID', [0, 1, 2, 3, 4])
+ rb_filter.with_filter(pred_row_id)
+ tr_filter = rb_filter.new_read()
+ splits_row_id = rb_filter.new_scan().plan().splits()
+ result_row_id = tr_filter.to_arrow(splits_row_id)
+ self.assertEqual(result_row_id.num_rows, 5, "Filter _ROW_ID in [0..4]
should return 5 rows")
+ a_col = result_row_id.column('a')
+ d_col_r = result_row_id.column('d')
+ for i in range(5):
+ self.assertEqual(a_col[i].as_py(), i + 1)
+ self.assertEqual(
+ d_col_r[i].as_py(),
+ (i + 1) * 100 + (i + 1) * 10 - (i + 1),
+ "Filter-by-_row_id row %d: d should be c+b-a" % i,
+ )
+
+ rb_slice = table.new_read_builder()
+ tr_slice = rb_slice.new_read()
+ slice_0 = rb_slice.new_scan().with_slice(0, 10).plan().splits()
+ result_slice_0 = tr_slice.to_arrow(slice_0)
+ self.assertEqual(result_slice_0.num_rows, 10, "with_slice(0, 10)
should return 10 rows")
+ d_s0 = result_slice_0.column('d')
+ for i in range(10):
+ self.assertEqual(
+ d_s0[i].as_py(),
+ (i + 1) * 100 + (i + 1) * 10 - (i + 1),
+ "Slice [0,10) row %d: d should be c+b-a" % i,
+ )
+ slice_1 = rb_slice.new_scan().with_slice(10, 21).plan().splits()
+ result_slice_1 = tr_slice.to_arrow(slice_1)
+ self.assertEqual(result_slice_1.num_rows, 11, "with_slice(10, 21)
should return 11 rows")
+ d_s1 = result_slice_1.column('d')
+ for i in range(11):
+ self.assertIsNone(d_s1[i].as_py(), "Slice [10,21) row %d: d should
be null" % i)
+
+ cross_slice = rb_slice.new_scan().with_slice(5, 16).plan().splits()
+ result_cross = tr_slice.to_arrow(cross_slice)
+ self.assertEqual(
+ result_cross.num_rows, 11,
+ "Cross-shard with_slice(5, 16) should return 11 rows (5 from file1
+ 6 from file2)",
+ )
+ a_cross = result_cross.column('a')
+ d_cross = result_cross.column('d')
+ for i in range(5):
+ self.assertEqual(a_cross[i].as_py(), 6 + i)
+ self.assertEqual(
+ d_cross[i].as_py(),
+ (6 + i) * 100 + (6 + i) * 10 - (6 + i),
+ "Cross-shard slice row %d (from file1): d should be c+b-a" % i,
+ )
+ for i in range(5, 11):
+ self.assertEqual(a_cross[i].as_py(), 10 + (i - 5))
+ self.assertIsNone(d_cross[i].as_py(), "Cross-shard slice row %d
(from file2): d null" % i)
+
+ rb_col = table.new_read_builder()
+ rb_col.with_projection(['a', 'b', 'c', 'd'])
+ pb_col = rb_col.new_predicate_builder()
+ pred_d = pb_col.is_in('d', [109, 218]) # d = c+b-a for a=1,2
+ rb_col.with_filter(pred_d)
+ tr_col = rb_col.new_read()
+ splits_d = rb_col.new_scan().plan().splits()
+ result_d = tr_col.to_arrow(splits_d)
+ self.assertEqual(result_d.num_rows, 2, "Filter d in [109, 218] should
return 2 rows")
+ a_d = result_d.column('a')
+ d_d = result_d.column('d')
+ self.assertEqual(a_d[0].as_py(), 1)
+ self.assertEqual(d_d[0].as_py(), 109)
+ self.assertEqual(a_d[1].as_py(), 2)
+ self.assertEqual(d_d[1].as_py(), 218)
+
+ def test_read_projection(self):
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'}
+ )
+ name = self._create_unique_table_name('read_proj')
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write().with_write_type(['a', 'b',
'c'])
+ table_commit = write_builder.new_commit()
+ init_data = pa.Table.from_pydict(
+ {'a': [1, 2, 3], 'b': [10, 20, 30], 'c': [100, 200, 300]},
+ schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c',
pa.int32())])
+ )
+ table_write.write_arrow(init_data)
+ cmts = table_write.prepare_commit()
+ for cmt in cmts:
+ for nf in cmt.new_files:
+ nf.first_row_id = 0
+ table_commit.commit(cmts)
+ table_write.close()
+ table_commit.close()
+
+ table_update = write_builder.new_update()
+ table_update.with_read_projection(['a', 'b', 'c'])
+ table_update.with_update_type(['a'])
+ shard_updator = table_update.new_shard_updator(0, 1)
+ reader = shard_updator.arrow_reader()
+
+ batch = reader.read_next_batch()
+ self.assertIsNotNone(batch, "Should have at least one batch")
+ actual_columns = set(batch.schema.names)
+
+ expected_columns = {'a', 'b', 'c'}
+ self.assertEqual(
+ actual_columns,
+ expected_columns,
+ "with_read_projection(['a','b','c']) should return only a,b,c; "
+ "got %s. _ROW_ID and _SEQUENCE_NUMBER should NOT be returned when
not in projection."
+ % actual_columns
+ )
+
+
if __name__ == '__main__':
unittest.main()