This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch release-1.3
in repository https://gitbox.apache.org/repos/asf/paimon.git
commit cddedd03419ce18e2b54eee768aabc3611f22fcb
Author: umi <[email protected]>
AuthorDate: Thu Oct 23 14:53:19 2025 +0800

    [Python] Support schema evolution read for changing column position (#6458)
---
 paimon-python/pypaimon/common/predicate.py          |  21 +-
 .../pypaimon/manifest/manifest_file_manager.py      |  42 ++-
 .../pypaimon/manifest/simple_stats_evolutions.py    |   5 +-
 .../pypaimon/read/scanner/full_starting_scanner.py  |   9 +-
 paimon-python/pypaimon/read/split_read.py           |  53 +++-
 paimon-python/pypaimon/read/table_read.py           |  36 ++-
 paimon-python/pypaimon/tests/pvfs_test.py           |   3 +-
 .../pypaimon/tests/reader_append_only_test.py       |   2 +-
 paimon-python/pypaimon/tests/rest/rest_server.py    |   7 +-
 .../pypaimon/tests/schema_evolution_read_test.py    | 328 +++++++++++++++++++++
 10 files changed, 454 insertions(+), 52 deletions(-)

diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py
index 9ae2cdfce3..89c82c9de2 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -94,10 +94,10 @@ class Predicate:
     def to_arrow(self) -> Any:
         if self.method == 'and':
-            return reduce(lambda x, y: x & y,
+            return reduce(lambda x, y: (x[0] & y[0], x[1] | y[1]),
                           [p.to_arrow() for p in self.literals])
 
         if self.method == 'or':
-            return reduce(lambda x, y: x | y,
+            return reduce(lambda x, y: (x[0] | y[0], x[1] | y[1]),
                           [p.to_arrow() for p in self.literals])
 
         if self.method == 'startsWith':
@@ -108,10 +108,11 @@ class Predicate:
                 # Ensure the field is cast to string type
                 string_field = field_ref.cast(pyarrow.string())
                 result = pyarrow_compute.starts_with(string_field, pattern)
-                return result
+                return result, {self.field}
             except Exception:
                 # Fallback to True
-                return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+                return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+                        {self.field})
 
         if self.method == 'endsWith':
             pattern = self.literals[0]  # For PyArrow compatibility
@@ -120,10 +121,11 @@ class Predicate:
                 # Ensure the field is cast to string type
                 string_field = field_ref.cast(pyarrow.string())
                 result = pyarrow_compute.ends_with(string_field, pattern)
-                return result
+                return result, {self.field}
             except Exception:
                 # Fallback to True
-                return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+                return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+                        {self.field})
 
         if self.method == 'contains':
             pattern = self.literals[0]  # For PyArrow compatibility
@@ -132,15 +134,16 @@ class Predicate:
                 # Ensure the field is cast to string type
                 string_field = field_ref.cast(pyarrow.string())
                 result = pyarrow_compute.match_substring(string_field, pattern)
-                return result
+                return result, {self.field}
             except Exception:
                 # Fallback to True
-                return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+                return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+                        {self.field})
 
         field = pyarrow_dataset.field(self.field)
         tester = Predicate.testers.get(self.method)
         if tester:
-            return tester.test_by_arrow(field, self.literals)
+            return tester.test_by_arrow(field, self.literals), {self.field}
 
         raise ValueError("Unsupported predicate method: {}".format(self.method))
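With this change Predicate.to_arrow returns a pair of (PyArrow expression, set of referenced field names) instead of a bare expression, so downstream readers can tell which columns a pushed-down filter actually touches. A rough sketch of how such pairs compose (illustrative only; combine_and is a hypothetical helper, not part of this commit):

    from functools import reduce

    def combine_and(pairs):
        # pairs: [(pyarrow dataset expression, {field names}), ...]
        # AND the expressions together and union the referenced fields,
        # mirroring the 'and' branch of Predicate.to_arrow above.
        return reduce(lambda x, y: (x[0] & y[0], x[1] | y[1]), pairs)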
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index b635f9e49c..927dad4674 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -26,6 +26,7 @@ from pypaimon.manifest.schema.manifest_entry import (MANIFEST_ENTRY_SCHEMA,
                                                      ManifestEntry)
 from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
 from pypaimon.manifest.schema.simple_stats import SimpleStats
+from pypaimon.schema.table_schema import TableSchema
 from pypaimon.table.row.generic_row import (GenericRowDeserializer,
                                             GenericRowSerializer)
 from pypaimon.table.row.binary_row import BinaryRow
@@ -43,6 +44,7 @@ class ManifestFileManager:
         self.partition_keys_fields = self.table.partition_keys_fields
         self.primary_keys_fields = self.table.primary_keys_fields
         self.trimmed_primary_keys_fields = self.table.trimmed_primary_keys_fields
+        self.schema_cache = {}
 
     def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], manifest_entry_filter=None,
                               drop_stats=True, max_workers=8) -> List[ManifestEntry]:
@@ -86,17 +88,9 @@ class ManifestFileManager:
                 null_counts=key_dict['_NULL_COUNTS'],
             )
 
+            schema_fields = self._get_schema(file_dict['_SCHEMA_ID']).fields
+            fields = self._get_value_stats_fields(file_dict, schema_fields)
             value_dict = dict(file_dict['_VALUE_STATS'])
-            if file_dict['_VALUE_STATS_COLS'] is None:
-                if file_dict['_WRITE_COLS'] is None:
-                    fields = self.table.table_schema.fields
-                else:
-                    read_fields = file_dict['_WRITE_COLS']
-                    fields = [self.table.field_dict[col] for col in read_fields]
-            elif not file_dict['_VALUE_STATS_COLS']:
-                fields = []
-            else:
-                fields = [self.table.field_dict[col] for col in file_dict['_VALUE_STATS_COLS']]
             value_stats = SimpleStats(
                 min_values=BinaryRow(value_dict['_MIN_VALUES'], fields),
                 max_values=BinaryRow(value_dict['_MAX_VALUES'], fields),
@@ -121,8 +115,8 @@ class ManifestFileManager:
                 file_source=file_dict['_FILE_SOURCE'],
                 value_stats_cols=file_dict.get('_VALUE_STATS_COLS'),
                 external_path=file_dict.get('_EXTERNAL_PATH'),
-                first_row_id=file_dict['_FIRST_ROW_ID'],
-                write_cols=file_dict['_WRITE_COLS'],
+                first_row_id=file_dict['_FIRST_ROW_ID'] if '_FIRST_ROW_ID' in file_dict else None,
+                write_cols=file_dict['_WRITE_COLS'] if '_WRITE_COLS' in file_dict else None,
             )
             entry = ManifestEntry(
                 kind=record['_KIND'],
@@ -138,6 +132,30 @@ class ManifestFileManager:
             entries.append(entry)
         return entries
 
+    def _get_value_stats_fields(self, file_dict: dict, schema_fields: list) -> List:
+        if file_dict['_VALUE_STATS_COLS'] is None:
+            if '_WRITE_COLS' in file_dict:
+                if file_dict['_WRITE_COLS'] is None:
+                    fields = schema_fields
+                else:
+                    read_fields = file_dict['_WRITE_COLS']
+                    fields = [self.table.field_dict[col] for col in read_fields]
+            else:
+                fields = schema_fields
+        elif not file_dict['_VALUE_STATS_COLS']:
+            fields = []
+        else:
+            fields = [self.table.field_dict[col] for col in file_dict['_VALUE_STATS_COLS']]
+        return fields
+
+    def _get_schema(self, schema_id: int) -> TableSchema:
+        if schema_id not in self.schema_cache:
+            schema = self.table.schema_manager.read_schema(schema_id)
+            if schema is None:
+                raise ValueError(f"Schema {schema_id} not found")
+            self.schema_cache[schema_id] = schema
+        return self.schema_cache[schema_id]
+
     def write(self, file_name, entries: List[ManifestEntry]):
         avro_records = []
         for entry in entries:
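The manifest reader now resolves value stats against the schema the data file was written with (file_dict['_SCHEMA_ID']), caching each schema in self.schema_cache, instead of assuming the current table schema. The precedence implemented by _get_value_stats_fields is roughly: explicit _VALUE_STATS_COLS first, then _WRITE_COLS, then the full write-schema field list. A condensed sketch of that precedence (illustrative standalone helper, not the committed code):

    def stats_fields(value_stats_cols, write_cols, schema_fields, field_dict):
        # hypothetical helper mirroring _get_value_stats_fields above
        if value_stats_cols is None:
            if write_cols is None:
                return schema_fields              # stats cover the whole write schema
            return [field_dict[c] for c in write_cols]
        if not value_stats_cols:
            return []                             # stats were dropped
        return [field_dict[c] for c in value_stats_cols]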
diff --git a/paimon-python/pypaimon/manifest/simple_stats_evolutions.py b/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
index 0b99acab21..df417d595b 100644
--- a/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
+++ b/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
@@ -28,8 +28,7 @@ class SimpleStatsEvolutions:
     def __init__(self, schema_fields: Callable[[int], List[DataField]], table_schema_id: int):
         self.schema_fields = schema_fields
         self.table_schema_id = table_schema_id
-        self.table_data_fields = schema_fields(table_schema_id)
-        self.table_fields = None
+        self.table_fields = schema_fields(table_schema_id)
         self.evolutions: Dict[int, SimpleStatsEvolution] = {}
 
     def get_or_create(self, data_schema_id: int) -> SimpleStatsEvolution:
@@ -40,8 +39,6 @@ class SimpleStatsEvolutions:
         if self.table_schema_id == data_schema_id:
             evolution = SimpleStatsEvolution(self.schema_fields(data_schema_id), None, None)
         else:
-            if self.table_fields is None:
-                self.table_fields = self.table_data_fields
             data_fields = self.schema_fields(data_schema_id)
 
             index_cast_mapping = self._create_index_cast_mapping(self.table_fields, data_fields)
diff --git a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
index cacf3ce343..44223b761a 100644
--- a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
@@ -64,13 +64,12 @@ class FullStartingScanner(StartingScanner):
             self.table.options.get('bucket', -1)) == BucketMode.POSTPONE_BUCKET.value else False
         self.data_evolution = self.table.options.get(CoreOptions.DATA_EVOLUTION_ENABLED,
                                                      'false').lower() == 'true'
-        self._schema_cache = {}
-
         def schema_fields_func(schema_id: int):
-            if schema_id not in self._schema_cache:
+            if schema_id not in self.manifest_file_manager.schema_cache:
                 schema = self.table.schema_manager.read_schema(schema_id)
-                self._schema_cache[schema_id] = schema
-            return self._schema_cache[schema_id].fields if self._schema_cache[schema_id] else []
+                self.manifest_file_manager.schema_cache[schema_id] = schema
+            return self.manifest_file_manager.schema_cache[schema_id].fields if self.manifest_file_manager.schema_cache[
+                schema_id] else []
 
         self.simple_stats_evolutions = SimpleStatsEvolutions(
             schema_fields_func,
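The scanner now shares one schema cache with ManifestFileManager and feeds it into SimpleStatsEvolutions, which maps the stats of a file written under an older schema onto the current table fields before scan filters are evaluated. A conceptual sketch of such a field-id based mapping (not from the commit; it assumes DataField exposes an integer id, as on the Java side):

    def index_mapping(table_fields, data_fields):
        # position of each field in the data file's schema, keyed by field id
        pos_by_id = {f.id: i for i, f in enumerate(data_fields)}
        # -1 means "this column did not exist when the file was written"
        return [pos_by_id.get(f.id, -1) for f in table_fields]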
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index 000e272e39..5c75cf6506 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -19,7 +19,7 @@
 import os
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import List, Optional, Tuple, Any
+from typing import List, Optional, Tuple, Any, Dict
 
 from pypaimon.common.core_options import CoreOptions
 from pypaimon.common.predicate import Predicate
@@ -46,6 +46,7 @@ from pypaimon.read.reader.key_value_wrap_reader import KeyValueWrapReader
 from pypaimon.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap
 from pypaimon.read.split import Split
 from pypaimon.schema.data_types import AtomicType, DataField
+from pypaimon.schema.table_schema import TableSchema
 
 KEY_PREFIX = "_KEY_"
 KEY_FIELD_ID_START = 1000000
@@ -55,12 +56,14 @@ NULL_FIELD_INDEX = -1
 class SplitRead(ABC):
     """Abstract base class for split reading operations."""
 
-    def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split):
+    def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split,
+                 schema_fields_cache: Dict):
        from pypaimon.table.file_store_table import FileStoreTable
 
         self.table: FileStoreTable = table
         self.predicate = predicate
-        self.push_down_predicate = self._push_down_predicate()
+        predicate_tuple = self._push_down_predicate()
+        self.push_down_predicate, self.predicate_fields = predicate_tuple if predicate_tuple else (None, None)
         self.split = split
 
         self.value_arity = len(read_type)
@@ -68,6 +71,7 @@
         self.read_fields = read_type
         if isinstance(self, MergeFileSplitRead):
             self.read_fields = self._create_key_value_fields(read_type)
+        self.schema_fields_cache = schema_fields_cache
 
     def _push_down_predicate(self) -> Any:
         if self.predicate is None:
@@ -84,21 +88,26 @@
     def create_reader(self) -> RecordReader:
         """Create a record reader for the given split."""
 
-    def file_reader_supplier(self, file_path: str, for_merge_read: bool, read_fields: List[str]):
+    def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, read_fields: List[str]):
+        read_file_fields, file_filter = self._get_schema(file.schema_id, read_fields)
+        if not file_filter:
+            return None
+
+        file_path = file.file_path
         _, extension = os.path.splitext(file_path)
         file_format = extension[1:]
         format_reader: RecordBatchReader
         if file_format == CoreOptions.FILE_FORMAT_AVRO:
-            format_reader = FormatAvroReader(self.table.file_io, file_path, read_fields,
+            format_reader = FormatAvroReader(self.table.file_io, file_path, read_file_fields,
                                              self.read_fields, self.push_down_predicate)
         elif file_format == CoreOptions.FILE_FORMAT_BLOB:
             blob_as_descriptor = CoreOptions.get_blob_as_descriptor(self.table.options)
-            format_reader = FormatBlobReader(self.table.file_io, file_path, read_fields,
+            format_reader = FormatBlobReader(self.table.file_io, file_path, read_file_fields,
                                              self.read_fields, self.push_down_predicate, blob_as_descriptor)
         elif file_format == CoreOptions.FILE_FORMAT_PARQUET or file_format == CoreOptions.FILE_FORMAT_ORC:
             format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path,
-                                                read_fields, self.push_down_predicate)
+                                                read_file_fields, self.push_down_predicate)
         else:
             raise ValueError(f"Unexpected file format: {file_format}")
 
@@ -111,6 +120,24 @@
         return DataFileBatchReader(format_reader, index_mapping, partition_info, None,
                                    self.table.table_schema.fields)
 
+    def _get_schema(self, schema_id: int, read_fields) -> TableSchema:
+        if schema_id not in self.schema_fields_cache[0]:
+            schema = self.table.schema_manager.read_schema(schema_id)
+            if schema is None:
+                raise ValueError(f"Schema {schema_id} not found")
+            self.schema_fields_cache[0][schema_id] = schema
+        schema = self.schema_fields_cache[0][schema_id]
+        fields_key = (schema_id, tuple(read_fields))
+        if fields_key not in self.schema_fields_cache[1]:
+            schema_field_names = set(field.name for field in schema.fields)
+            if self.table.is_primary_key_table:
+                schema_field_names.add('_SEQUENCE_NUMBER')
+                schema_field_names.add('_VALUE_KIND')
+            self.schema_fields_cache[1][fields_key] = (
+                [read_field for read_field in read_fields if read_field in schema_field_names],
+                False if self.predicate_fields and self.predicate_fields - schema_field_names else True)
+        return self.schema_fields_cache[1][fields_key]
+
     @abstractmethod
     def _get_all_data_fields(self):
         """Get all data fields"""
@@ -263,10 +290,10 @@ class RawFileSplitRead(SplitRead):
     def create_reader(self) -> RecordReader:
         data_readers = []
-        for file_path in self.split.file_paths:
+        for file in self.split.files:
             supplier = partial(
                 self.file_reader_supplier,
-                file_path=file_path,
+                file=file,
                 for_merge_read=False,
                 read_fields=self._get_final_read_data_fields(),
             )
@@ -289,10 +316,10 @@ class RawFileSplitRead(SplitRead):
 class MergeFileSplitRead(SplitRead):
-    def kv_reader_supplier(self, file_path):
+    def kv_reader_supplier(self, file):
         reader_supplier = partial(
             self.file_reader_supplier,
-            file_path=file_path,
+            file=file,
             for_merge_read=True,
             read_fields=self._get_final_read_data_fields()
         )
@@ -303,7 +330,7 @@ class MergeFileSplitRead(SplitRead):
         for sorter_run in section:
             data_readers = []
             for file in sorter_run.files:
-                supplier = partial(self.kv_reader_supplier, file.file_path)
+                supplier = partial(self.kv_reader_supplier, file)
                 data_readers.append(supplier)
             readers.append(ConcatRecordReader(data_readers))
         return SortMergeReaderWithMinHeap(readers, self.table.table_schema)
@@ -468,7 +495,7 @@ class DataEvolutionSplitRead(SplitRead):
     def _create_file_reader(self, file: DataFileMeta, read_fields: [str]) -> RecordReader:
         """Create a file reader for a single file."""
-        return self.file_reader_supplier(file_path=file.file_path, for_merge_read=False, read_fields=read_fields)
+        return self.file_reader_supplier(file=file, for_merge_read=False, read_fields=read_fields)
 
     def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]:
         """Split files into field bunches."""
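file_reader_supplier now receives the whole DataFileMeta instead of a path, restricts the requested columns to those that existed in the file's write schema, and returns None (skipping the file) when the pushed-down predicate references a column the file never had. Conceptually (illustrative sketch, not the committed code):

    def plan_file_read(requested_cols, file_schema_cols, predicate_cols):
        # columns we can actually read from this particular data file
        readable = [c for c in requested_cols if c in file_schema_cols]
        # if the filter needs a column this file never had, the file is skipped
        skip = bool(predicate_cols - set(file_schema_cols))
        return readable, skip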
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 31545e4ea4..6cd544e745 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -39,6 +39,7 @@ class TableRead:
         self.table: FileStoreTable = table
         self.predicate = predicate
         self.read_type = read_type
+        self.schema_fields_cache = ({}, {})
 
     def to_iterator(self, splits: List[Split]) -> Iterator:
         def _record_generator():
@@ -57,10 +58,32 @@
         batch_iterator = self._arrow_batch_generator(splits, schema)
         return pyarrow.ipc.RecordBatchReader.from_batches(schema, batch_iterator)
 
+    def _pad_batch_to_schema(self, batch: pyarrow.RecordBatch, target_schema):
+        columns = []
+        num_rows = batch.num_rows
+
+        for field in target_schema:
+            if field.name in batch.column_names:
+                col = batch.column(field.name)
+            else:
+                col = pyarrow.nulls(num_rows, type=field.type)
+            columns.append(col)
+
+        return pyarrow.RecordBatch.from_arrays(columns, schema=target_schema)
+
     def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]:
         batch_reader = self.to_arrow_batch_reader(splits)
-        arrow_table = batch_reader.read_all()
-        return arrow_table
+
+        schema = PyarrowFieldParser.from_paimon_schema(self.read_type)
+        table_list = []
+        for batch in iter(batch_reader.read_next_batch, None):
+            table_list.append(batch) if schema == batch.schema \
+                else table_list.append(self._pad_batch_to_schema(batch, schema))
+
+        if not table_list:
+            return pyarrow.Table.from_arrays([pyarrow.array([], type=field.type) for field in schema], schema=schema)
+        else:
+            return pyarrow.Table.from_batches(table_list)
 
     def _arrow_batch_generator(self, splits: List[Split], schema: pyarrow.Schema) -> Iterator[pyarrow.RecordBatch]:
         chunk_size = 65536
@@ -112,21 +135,24 @@
                 table=self.table,
                 predicate=self.predicate,
                 read_type=self.read_type,
-                split=split
+                split=split,
+                schema_fields_cache=self.schema_fields_cache
             )
         elif self.table.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, 'false').lower() == 'true':
             return DataEvolutionSplitRead(
                 table=self.table,
                 predicate=self.predicate,
                 read_type=self.read_type,
-                split=split
+                split=split,
+                schema_fields_cache=self.schema_fields_cache
             )
         else:
             return RawFileSplitRead(
                 table=self.table,
                 predicate=self.predicate,
                 read_type=self.read_type,
-                split=split
+                split=split,
+                schema_fields_cache=self.schema_fields_cache
             )
 
     @staticmethod
diff --git a/paimon-python/pypaimon/tests/pvfs_test.py b/paimon-python/pypaimon/tests/pvfs_test.py
index 29ef979f9e..7bebceb96e 100644
--- a/paimon-python/pypaimon/tests/pvfs_test.py
+++ b/paimon-python/pypaimon/tests/pvfs_test.py
@@ -151,7 +151,8 @@ class PVFSTest(unittest.TestCase):
         self.assertEqual(table_virtual_path, self.pvfs.info(table_virtual_path).get('name'))
         self.assertEqual(True, self.pvfs.exists(database_virtual_path))
         user_dirs = self.pvfs.ls(f"pvfs://{self.catalog}/{self.database}/{self.table}", detail=False)
-        self.assertSetEqual(set(user_dirs), {f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'})
+        self.assertSetEqual(set(user_dirs), {f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}',
+                                             f'pvfs://{self.catalog}/{self.database}/{self.table}/schema'})
 
         data_file_name = 'data.txt'
         data_file_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 0367ab409c..3a99196854 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -38,7 +38,7 @@ class AoReaderTest(unittest.TestCase):
         cls.catalog = CatalogFactory.create({
             'warehouse': cls.warehouse
         })
-        cls.catalog.create_database('default', False)
+        cls.catalog.create_database('default', True)
 
         cls.pa_schema = pa.schema([
             ('user_id', pa.int32()),
diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py
index 8f7cf23944..d2908e59be 100644
--- a/paimon-python/pypaimon/tests/rest/rest_server.py
+++ b/paimon-python/pypaimon/tests/rest/rest_server.py
@@ -428,12 +428,15 @@ class RESTCatalogServer:
             if create_table.identifier.get_full_name() in self.table_metadata_store:
                 raise TableAlreadyExistException(create_table.identifier)
             table_metadata = self._create_table_metadata(
-                create_table.identifier, 1, create_table.schema, str(uuid.uuid4()), False
+                create_table.identifier, 0, create_table.schema, str(uuid.uuid4()), False
             )
             self.table_metadata_store.update({create_table.identifier.get_full_name(): table_metadata})
-            table_dir = Path(self.data_path) / self.warehouse / database_name / create_table.identifier.object_name
+            table_dir = Path(
+                self.data_path) / self.warehouse / database_name / create_table.identifier.object_name / 'schema'
             if not table_dir.exists():
                 table_dir.mkdir(parents=True)
+            with open(table_dir / "schema-0", "w") as f:
+                f.write(JSON.to_json(table_metadata.schema, indent=2))
             return self._mock_response("", 200)
         return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
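On the read side, the table_read.py change above no longer concatenates raw batches: batches coming from files written with an older schema are padded with null columns so every batch matches the requested read schema before assembling the result table. A small self-contained PyArrow example of that padding idea (illustrative only, not the committed code):

    import pyarrow as pa

    target = pa.schema([('user_id', pa.int64()), ('behavior', pa.string())])
    batch = pa.record_batch({'user_id': pa.array([1, 2], type=pa.int64())})  # no 'behavior' yet
    cols = [batch.column(f.name) if f.name in batch.schema.names
            else pa.nulls(batch.num_rows, type=f.type)
            for f in target]
    padded = pa.RecordBatch.from_arrays(cols, schema=target)  # behavior == [null, null]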
diff --git a/paimon-python/pypaimon/tests/schema_evolution_read_test.py b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
new file mode 100644
index 0000000000..dde1f2c15f
--- /dev/null
+++ b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
@@ -0,0 +1,328 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+from pypaimon.schema.schema_manager import SchemaManager
+from pypaimon.schema.table_schema import TableSchema
+
+
+class SchemaEvolutionReadTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({
+            'warehouse': cls.warehouse
+        })
+        cls.catalog.create_database('default', False)
+
+        cls.pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string())
+        ])
+        cls.raw_data = {
+            'user_id': [1, 2, 3, 4, 5],
+            'item_id': [1001, 1002, 1003, 1004, 1005],
+            'behavior': ['a', 'b', 'c', None, 'e'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p2'],
+        }
+        cls.expected = pa.Table.from_pydict(cls.raw_data, schema=cls.pa_schema)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def test_schema_evolution(self):
+        # schema 0
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('dt', pa.string())
+        ])
+        schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_sample', schema, False)
+        table1 = self.catalog.get_table('default.test_sample')
+        write_builder = table1.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1001, 1002, 1003, 1004],
+            'dt': ['p1', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # schema 1 add behavior column
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('dt', pa.string()),
+            ('behavior', pa.string())
+        ])
+        schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution', schema2, False)
+        table2 = self.catalog.get_table('default.test_schema_evolution')
+        table2.table_schema.id = 1
+        write_builder = table2.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+            'behavior': ['e', 'f', 'g', 'h'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # write schema-0 and schema-1 to table2
+        schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+
+        splits = self._scan_table(table1.new_read_builder())
+        read_builder = table2.new_read_builder()
+        splits2 = self._scan_table(read_builder)
+        splits.extend(splits2)
+
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(splits)
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 4, 3, 5, 7, 8, 6],
+            'item_id': [1001, 1002, 1004, 1003, 1005, 1007, 1008, 1006],
+            'dt': ["p1", "p1", "p1", "p2", "p2", "p2", "p2", "p1"],
+            'behavior': [None, None, None, None, "e", "g", "h", "f"],
+        }, schema=pa_schema)
+        self.assertEqual(expected, actual)
+
+    def test_schema_evolution_with_read_filter(self):
+        # schema 0
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('dt', pa.string())
+        ])
+        schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution_with_filter', schema, False)
+        table1 = self.catalog.get_table('default.test_schema_evolution_with_filter')
+        write_builder = table1.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1001, 1002, 1003, 1004],
+            'dt': ['p1', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # schema 1 add behavior column
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('dt', pa.string()),
+            ('behavior', pa.string())
+        ])
+        schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution_with_filter2', schema2, False)
+        table2 = self.catalog.get_table('default.test_schema_evolution_with_filter2')
+        table2.table_schema.id = 1
+        write_builder = table2.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+            'behavior': ['e', 'f', 'g', 'h'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # write schema-0 and schema-1 to table2
+        schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+        # behavior filter
+        splits = self._scan_table(table1.new_read_builder())
+
+        read_builder = table2.new_read_builder()
+        predicate_builder = read_builder.new_predicate_builder()
+        predicate = predicate_builder.not_equal('behavior', "g")
+        splits2 = self._scan_table(read_builder.with_filter(predicate))
+        for split in splits2:
+            for file in split.files:
+                file.schema_id = 1
+        splits.extend(splits2)
+
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(splits)
+        expected = pa.Table.from_pydict({
+            'user_id': [5, 8, 6],
+            'item_id': [1005, 1008, 1006],
+            'dt': ["p2", "p2", "p1"],
+            'behavior': ["e", "h", "f"],
+        }, schema=pa_schema)
+        self.assertEqual(expected, actual)
+        # user_id filter
+        splits = self._scan_table(table1.new_read_builder())
+
+        read_builder = table2.new_read_builder()
+        predicate_builder = read_builder.new_predicate_builder()
+        predicate = predicate_builder.less_than('user_id', 6)
+        splits2 = self._scan_table(read_builder.with_filter(predicate))
+        self.assertEqual(1, len(splits2))
+        for split in splits2:
+            for file in split.files:
+                file.schema_id = 1
+        splits.extend(splits2)
+
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(splits)
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 4, 3, 5],
+            'item_id': [1001, 1002, 1004, 1003, 1005],
+            'dt': ["p1", "p1", "p1", "p2", "p2"],
+            'behavior': [None, None, None, None, "e"],
+        }, schema=pa_schema)
+        self.assertEqual(expected, actual)
+
+    def test_schema_evolution_with_scan_filter(self):
+        # schema 0
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('dt', pa.string())
+        ])
+        schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution1', schema, False)
+        table1 = self.catalog.get_table('default.test_schema_evolution1')
+        write_builder = table1.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1001, 1002, 1003, 1004],
+            'dt': ['p1', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # schema 1 add behavior column
+        pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string())
+        ])
+        schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution2', schema2, False)
+        table2 = self.catalog.get_table('default.test_schema_evolution2')
+        table2.table_schema.id = 1
+        write_builder = table2.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'behavior': ['e', 'f', 'g', 'h'],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # write schema-0 and schema-1 to table2
+        schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+        # scan filter for schema evolution
+        latest_snapshot = table1.new_read_builder().new_scan().starting_scanner.snapshot_manager.get_latest_snapshot()
+        table2.table_path = table1.table_path
+        new_read_buidler = table2.new_read_builder()
+        predicate_builder = new_read_buidler.new_predicate_builder()
+        predicate = predicate_builder.less_than('user_id', 3)
+        new_scan = new_read_buidler.with_filter(predicate).new_scan()
+        manifest_files = new_scan.starting_scanner.manifest_list_manager.read_all(latest_snapshot)
+        entries = new_scan.starting_scanner.read_manifest_entries(manifest_files)
+        self.assertEqual(1, len(entries))  # verify scan filter success for schema evolution
+
+    def _write_test_table(self, table):
+        write_builder = table.new_batch_write_builder()
+
+        # first write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1001, 1002, 1003, 1004],
+            'behavior': ['a', 'b', 'c', None],
+            'dt': ['p1', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # second write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'behavior': ['e', 'f', 'g', 'h'],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+    def _scan_table(self, read_builder):
+        splits = read_builder.new_scan().plan().splits()
+        return splits
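Taken together, the tests above show the intended end-to-end behavior: splits whose files were written under schema 0 and schema 1 can be read through one read builder, and columns added later (even in a different position, as in test_schema_evolution_with_scan_filter where 'behavior' is inserted before 'dt') come back as nulls in the current schema's column order. A minimal usage sketch in the same style as the tests (table and column names are hypothetical):

    read_builder = table.new_read_builder()
    predicate = read_builder.new_predicate_builder().less_than('user_id', 6)
    splits = read_builder.with_filter(predicate).new_scan().plan().splits()
    result = read_builder.new_read().to_arrow(splits)
    # rows from files written before 'behavior' existed have behavior == null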
