This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 0978e4c175 [ray] Introduce Ray Data merge into (#8028)
0978e4c175 is described below
commit 0978e4c17512a6bbf441801ae1a3f9a83bf31eb5
Author: XiaoHongbo <[email protected]>
AuthorDate: Mon Jun 1 21:54:43 2026 +0800
[ray] Introduce Ray Data merge into (#8028)
---
.github/workflows/paimon-python-checks.yml | 4 +-
docs/docs/pypaimon/ray-data.md | 49 +++
paimon-python/dev/requirements-dev.txt | 5 +-
paimon-python/pypaimon/ray/__init__.py | 13 +-
.../pypaimon/ray/data_evolution_merge_into.py | 450 ++++++++++++++++++++
.../pypaimon/ray/data_evolution_merge_join.py | 353 ++++++++++++++++
.../pypaimon/ray/data_evolution_merge_transform.py | 121 ++++++
.../tests/ray_data_evolution_merge_into_test.py | 470 +++++++++++++++++++++
paimon-python/pypaimon/write/ray_datasink.py | 26 +-
.../pypaimon/write/table_update_by_row_id.py | 61 ++-
10 files changed, 1520 insertions(+), 32 deletions(-)
diff --git a/.github/workflows/paimon-python-checks.yml
b/.github/workflows/paimon-python-checks.yml
index 864d10ecee..6a88767590 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -133,7 +133,7 @@ jobs:
else
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
- python -m pip install pyroaring readerwriterlock==1.0.9
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0
py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0
+ python -m pip install pyroaring readerwriterlock==1.0.9
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0
py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0
python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION
}}' -i https://pypi.org/simple/
if python -c "import sys; sys.exit(0 if sys.version_info >= (3,
11) else 1)"; then
python -m pip install vortex-data==0.70.0
@@ -184,7 +184,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
- python -m pip install pyroaring readerwriterlock==1.0.9
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0
py4j==0.10.9.9 requests parameterized==0.9.0
+ python -m pip install pyroaring readerwriterlock==1.0.9
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0
py4j==0.10.9.9 requests parameterized==0.9.0
python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION
}}' -i https://pypi.org/simple/
- name: Run lint-python.sh
shell: bash
diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index b411561322..e19e987c4b 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -314,3 +314,52 @@ write_builder = table.new_batch_write_builder().overwrite()
# overwrite partition 'dt=2024-01-01'
write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'})
```
+
+## Merge Into
+
+`merge_into` updates (and optionally inserts) rows of a **data-evolution** table
+from a source, like SQL `MERGE INTO`. Matched rows are updated in place by
+`_ROW_ID`; only the touched columns are rewritten. Requires `ray >= 2.50` and a
+target table with both `'data-evolution.enabled'` and `'row-tracking.enabled'`
+set to `'true'`.
+
+```python
+from pypaimon.ray import merge_into, WhenMatched, WhenNotMatched
+
+metrics = merge_into(
+ target="database_name.table_name",
+ source=ray_dataset, # ray.data.Dataset / pa.Table / pandas /
table-name str
+ catalog_options={"warehouse": "/path/to/warehouse"},
+ on=["id"], # or {"target_col": "source_col"} for renamed
keys
+ when_matched=[WhenMatched(update="*")],
+ when_not_matched=[WhenNotMatched(insert="*")], # optional
+)
+print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
+```
+
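+- `update` / `insert`: only `"*"` is currently supported. It sets every
+  non-blob target column from the source column of the same name (a renamed
+  `on` key is read from its source-side name), so the source must contain all
+  of those columns. A future follow-up will add mapping-based SET (e.g.
+  `{"col": "s.col"}`) where values are analyzable string expressions
+  (`"s.<col>"`, `"t.<col>"`, or literals), not Python callables.
+- `condition`: reserved for a future follow-up; passing a non-None value
+  currently raises `NotImplementedError`.
+- At most one `WhenMatched` and one `WhenNotMatched` clause may be given;
+  passing more raises `NotImplementedError`.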
+
+**Parameters:**
+- `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
+ Paimon table identifier string. When a string is passed, it reads the table
+ from the same `catalog_options` at the latest snapshot.
+- `on`: key columns, or `{target_col: source_col}` for renamed keys.
+- `num_partitions`: shuffle parallelism for the joins and the grouped update
+  write; defaults to `max(1, cluster_cpus * 2)`, or `4` if the cluster
+  resources cannot be queried. Raise it for large merges on big clusters.
+- `ray_remote_args`: Ray remote options applied to the merge's map/group
+  tasks (update transform, group write, insert transform) and passed to the
+  insert sink.
+- `concurrency`: concurrency passed to Ray's `write_datasink` for the insert
+  sink.
+
+**Returns:** a dict with `num_matched`, `num_inserted` and `num_unchanged`.
+Because conditions are not supported yet, every matched row is updated:
+`num_matched` is the number of updated rows and `num_unchanged` is always
+`0`. Conditional clauses (planned) can make `num_unchanged > 0`.
+
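+**Notes:**
+- Blob columns are not written by `merge_into`: update leaves the existing
+  `.blob` files untouched, and insert fills blob columns with `NULL`. The
+  source data does not need to (and should not) carry blob columns.
+- Each target row may match at most one source row. If several source rows
+  share the same key as a target row, the merge fails with a `ValueError`;
+  deduplicate the source first.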
diff --git a/paimon-python/dev/requirements-dev.txt
b/paimon-python/dev/requirements-dev.txt
index d4e9a0645b..9ef88817f7 100644
--- a/paimon-python/dev/requirements-dev.txt
+++ b/paimon-python/dev/requirements-dev.txt
@@ -21,8 +21,9 @@
duckdb==1.3.2
flake8==4.0.1
pytest~=7.0
-# Ray: 2.48+ has no wheel for Python 3.8; use 2.10.0 on 3.8, 2.48.0 on 3.9+
-ray>=2.10.0
+# merge_into needs Dataset.join (added in Ray 2.50). Python 3.8 has no 2.50
wheel.
+ray>=2.10.0; python_version < "3.9"
+ray>=2.50.0; python_version >= "3.9"
requests
parameterized
# Vortex 0.71.0 regresses native predicate pushdown on single-row files.
diff --git a/paimon-python/pypaimon/ray/__init__.py
b/paimon-python/pypaimon/ray/__init__.py
index f36eb0253d..9161f3cbb3 100644
--- a/paimon-python/pypaimon/ray/__init__.py
+++ b/paimon-python/pypaimon/ray/__init__.py
@@ -16,5 +16,16 @@
# under the License.
from pypaimon.ray.ray_paimon import read_paimon, write_paimon
+from pypaimon.ray.data_evolution_merge_into import (
+ WhenMatched,
+ WhenNotMatched,
+ merge_into,
+)
-__all__ = ["read_paimon", "write_paimon"]
+__all__ = [
+ "read_paimon",
+ "write_paimon",
+ "merge_into",
+ "WhenMatched",
+ "WhenNotMatched",
+]
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
new file mode 100644
index 0000000000..7ab1ce70f8
--- /dev/null
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -0,0 +1,450 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""MERGE INTO ... USING ... for Paimon data-evolution tables via Ray
Datasets."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
+
+import pyarrow as pa
+
+from pypaimon.ray.data_evolution_merge_join import (
+ build_matched_update_ds,
+ build_not_matched_insert_ds,
+ distributed_update_apply,
+ distributed_write_collect_msgs,
+)
+from pypaimon.ray.data_evolution_merge_transform import (
+ OnSpec,
+ SetSpec,
+ WhenMatched,
+ WhenNotMatched,
+ _NormalizedClause,
+)
+
+__all__ = ["merge_into", "WhenMatched", "WhenNotMatched"]
+
+
@dataclass(frozen=True)
class _PrepareCtx:
    """Bag of values _prepare hands to _build_datasets.

    Built once by ``_prepare``; the dataclass is frozen, so fields cannot be
    reassigned (the contained lists/dicts are still mutable objects).
    """
    # Join-key column names on the target side, in ``on`` order.
    target_on_cols: List[str]
    # Join-key column names on the source side, positionally paired with
    # ``target_on_cols``.
    source_on_cols: List[str]
    # Target columns a SET spec may write: every target field except blob
    # columns.
    settable_field_names: List[str]
    # Every target field name, blob columns included (used by the insert path).
    full_target_field_names: List[str]
    # Arrow schema restricted to ``settable_field_names`` (update path).
    update_pa_schema: pa.Schema
    # Arrow schema of the whole target table (insert path).
    full_pa_schema: pa.Schema
    # Catalog options used to read the target table again from Ray tasks.
    catalog_options: Dict[str, str]
+
+
def merge_into(
    target: str,
    source: Any,
    catalog_options: Dict[str, str],
    *,
    on: OnSpec,
    when_matched: Sequence[WhenMatched] = (),
    when_not_matched: Sequence[WhenNotMatched] = (),
    num_partitions: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
) -> Dict[str, int]:
    """Run ``MERGE INTO target USING source ON ...`` on a data-evolution table.

    Matched target rows are rewritten in place by ``_ROW_ID``; unmatched
    source rows are appended when a ``WhenNotMatched`` clause is given. The
    commit messages of both branches are committed together in one commit.

    Args:
        target: Identifier of the target table, e.g. ``"db.table"``.
        source: A ``ray.data.Dataset``, ``pyarrow.Table``,
            ``pandas.DataFrame`` or a Paimon table identifier string (read
            through ``catalog_options``).
        catalog_options: Options used to create the catalog that owns
            ``target`` (and ``source`` when it is an identifier).
        on: Join key columns, or a ``{target_col: source_col}`` mapping for
            keys that are named differently on the two sides.
        when_matched: At most one ``WhenMatched`` clause. A bare clause is
            accepted as shorthand for a one-element sequence.
        when_not_matched: At most one ``WhenNotMatched`` clause. A bare
            clause is accepted as well.
        num_partitions: Shuffle parallelism for the joins and the grouped
            update; ``None`` derives it from the cluster's CPU count.
        ray_remote_args: Ray remote options for the merge's map/group tasks
            and the insert sink.
        concurrency: Concurrency for the insert sink.

    Returns:
        ``{"num_matched": ..., "num_inserted": ..., "num_unchanged": ...}``.

    Raises:
        RuntimeError: If the installed Ray is older than 2.50.
        ValueError: If no clause is given, the target lacks the required
            table options, or the source is missing required columns.
        NotImplementedError: For multiple clauses, clause conditions or a
            SET spec other than ``"*"``.
        TypeError: If ``source`` has an unsupported type.
    """
    # A single clause object is not iterable; wrap it instead of failing
    # with a confusing "object is not iterable" error in list(...) below.
    if isinstance(when_matched, WhenMatched):
        when_matched = [when_matched]
    if isinstance(when_not_matched, WhenNotMatched):
        when_not_matched = [when_not_matched]

    _require_ray_join()
    num_partitions = _resolve_num_partitions(num_partitions)

    table, source_ds, matched_specs, not_matched_specs, ctx = _prepare(
        target, source, catalog_options,
        list(when_matched), list(when_not_matched), on,
    )
    # Every target read and the commit conflict check are pinned to this
    # snapshot (None when the target has no snapshot yet).
    base_snapshot = table.snapshot_manager().get_latest_snapshot()

    update_ds, insert_ds, update_cols_union = _build_datasets(
        target, source_ds, matched_specs, not_matched_specs,
        ctx, base_snapshot, num_partitions, ray_remote_args,
    )

    return _execute_and_commit(
        table, update_ds, insert_ds, update_cols_union,
        base_snapshot, num_partitions,
        ray_remote_args, concurrency,
    )
+
+
def _prepare(target, source, catalog_options, when_matched, when_not_matched,
             on):
    """Validate a merge request and gather what ``_build_datasets`` needs.

    Checks the clause shape, opens the target table through the catalog,
    verifies its data-evolution / row-tracking options and that the target
    join keys exist, normalizes the SET specs and the source, and derives
    the Arrow schemas for the update and insert paths.

    Returns:
        ``(table, source_ds, matched_specs, not_matched_specs, ctx)`` where
        ``ctx`` is a ``_PrepareCtx``.

    Raises:
        ValueError: For missing clauses, missing table options, or join /
            target columns missing from the target or source.
        NotImplementedError: For multiple clauses or clause conditions.
    """
    if not when_matched and not when_not_matched:
        raise ValueError(
            "At least one of when_matched or when_not_matched must be "
            "non-empty."
        )
    if len(when_matched) > 1 or len(when_not_matched) > 1:
        raise NotImplementedError(
            "merge_into currently supports a single WhenMatched and a single "
            "WhenNotMatched clause; multi-clause fall-through will be added "
            "in a follow-up PR."
        )
    for clause in list(when_matched) + list(when_not_matched):
        if clause.condition is not None:
            raise NotImplementedError(
                "merge_into does not yet support condition expressions; "
                "this will be added in a follow-up PR."
            )
    target_on_cols, source_on_cols = _normalize_on(on)

    from pypaimon.catalog.catalog_factory import CatalogFactory

    catalog = CatalogFactory.create(catalog_options)
    table = catalog.get_table(target)
    if not table.options.data_evolution_enabled():
        raise ValueError(
            f"merge_into requires 'data-evolution.enabled' = 'true' on "
            f"'{target}'."
        )
    if not table.options.row_tracking_enabled():
        raise ValueError(
            f"merge_into requires 'row-tracking.enabled' = 'true' on "
            f"'{target}'."
        )

    full_target_field_names = list(table.field_names)
    # Fail fast on unknown target join keys; otherwise the problem would only
    # surface later, inside the Ray read/join of the target table.
    missing_target_on = [
        c for c in target_on_cols if c not in full_target_field_names
    ]
    if missing_target_on:
        raise ValueError(
            f"'on' columns {missing_target_on} missing from target table "
            f"'{target}' schema {full_target_field_names}."
        )

    blob_cols = _blob_col_names(table)
    # SET specs only cover non-blob columns: update can't rewrite blob files
    # (data evolution puts them in dedicated .blob files), and insert leaves
    # blob columns null since the source can't carry them through SET="*".
    settable_field_names = [
        c for c in full_target_field_names if c not in blob_cols
    ]
    on_map = dict(zip(target_on_cols, source_on_cols))
    matched_specs = [
        _NormalizedClause(
            spec=_normalize_set_spec(
                c.update, settable_field_names, on_map,
            ),
        )
        for c in when_matched
    ]
    not_matched_specs = [
        _NormalizedClause(
            spec=_normalize_set_spec(
                c.insert, settable_field_names, on_map,
            ),
        )
        for c in when_not_matched
    ]

    source_ds = _normalize_source(source, catalog_options)
    _validate_source_on_cols(source_ds, source_on_cols)
    _validate_source_has_target_cols(
        source_ds, settable_field_names, on_map,
    )

    from pypaimon.schema.data_types import PyarrowFieldParser
    full_pa_schema = PyarrowFieldParser.from_paimon_schema(
        table.table_schema.fields
    )
    # update_pa_schema strips blob (only non-blob cols are written by the
    # update path); insert_pa_schema is the full table schema so the writer
    # gets every column (blob columns end up null).
    update_pa_schema = pa.schema(
        [full_pa_schema.field(c) for c in settable_field_names]
    )
    ctx = _PrepareCtx(
        target_on_cols=target_on_cols,
        source_on_cols=source_on_cols,
        settable_field_names=settable_field_names,
        full_target_field_names=full_target_field_names,
        update_pa_schema=update_pa_schema,
        full_pa_schema=full_pa_schema,
        catalog_options=catalog_options,
    )
    return table, source_ds, matched_specs, not_matched_specs, ctx
+
+
def _build_datasets(
    target, source_ds, matched_specs, not_matched_specs,
    ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args,
):
    """Assemble the lazy update and insert datasets for a merge.

    Every target read is pinned to ``base_snapshot`` so both branches observe
    the same table version even if other writers commit in between.

    Returns:
        ``(update_ds, insert_ds, update_cols_union)``. A dataset is ``None``
        when its branch does not apply; ``update_cols_union`` stays empty
        when no update dataset is built.
    """
    target_is_empty = base_snapshot is None
    pinned_snapshot_id = None if target_is_empty else base_snapshot.id

    # Keyword arguments shared by both join builders.
    shared = dict(
        target_identifier=target,
        source_ds=source_ds,
        target_on=ctx.target_on_cols,
        source_on=ctx.source_on_cols,
        catalog_options=ctx.catalog_options,
        num_partitions=num_partitions,
        snapshot_id=pinned_snapshot_id,
        ray_remote_args=ray_remote_args,
    )

    # Like Spark, matched and not-matched rows come from two independent
    # joins (inner / left_anti) instead of one left_outer join whose
    # materialized result would have to feed both branches and could OOM.
    update_cols_union: List[str] = []
    update_ds = None
    if matched_specs and not target_is_empty:
        update_cols_union = _union_update_cols(matched_specs)
        update_ds = build_matched_update_ds(
            clauses=matched_specs,
            target_field_names=ctx.settable_field_names,
            target_pa_schema=ctx.update_pa_schema,
            update_cols=update_cols_union,
            resolve_target_projection=_resolve_target_projection,
            **shared,
        )

    insert_ds = None
    if not_matched_specs:
        # Inserts write the full target schema; the SET spec only covers
        # settable columns, so blob columns fall through to null.
        insert_ds = build_not_matched_insert_ds(
            clauses=not_matched_specs,
            target_field_names=ctx.full_target_field_names,
            target_pa_schema=ctx.full_pa_schema,
            target_empty=target_is_empty,
            **shared,
        )

    return update_ds, insert_ds, update_cols_union
+
+
+def _execute_and_commit(
+ table, update_ds, insert_ds, update_cols_union,
+ base_snapshot, num_partitions,
+ ray_remote_args, concurrency,
+):
+ update_msgs: list = []
+ num_updated = 0
+ if update_ds is not None:
+ try:
+ update_msgs, num_updated = distributed_update_apply(
+ update_ds, table, update_cols_union,
+ num_partitions=num_partitions,
+ ray_remote_args=ray_remote_args,
+ base_snapshot_id=(
+ base_snapshot.id
+ if base_snapshot is not None else None
+ ),
+ )
+ except Exception as e:
+ _reraise_inner(e)
+
+ all_msgs: list = list(update_msgs)
+ num_inserted = 0
+ if insert_ds is not None:
+ try:
+ insert_msgs = distributed_write_collect_msgs(
+ insert_ds, table,
+ ray_remote_args=ray_remote_args, concurrency=concurrency,
+ )
+ except Exception as e:
+ _reraise_inner(e)
+ num_inserted = sum(
+ f.row_count for m in insert_msgs for f in m.new_files
+ )
+ all_msgs.extend(insert_msgs)
+ # TODO: add global-index update action check after PR #8045 merges
+ if all_msgs:
+ wb = table.new_batch_write_builder()
+ tc = wb.new_commit()
+ tc.commit(all_msgs)
+ tc.close()
+
+ # MVP has no condition, so every matched row is updated; num_unchanged
+ # is always 0. Kept in the dict for API stability when condition lands.
+ return {
+ "num_matched": num_updated,
+ "num_inserted": num_inserted,
+ "num_unchanged": 0,
+ }
+
+
+def _normalize_on(on: OnSpec) -> Tuple[List[str], List[str]]:
+ if isinstance(on, Mapping):
+ target_cols = list(on.keys())
+ source_cols = list(on.values())
+ else:
+ target_cols = list(on)
+ source_cols = list(on)
+ if not target_cols:
+ raise ValueError("'on' must be non-empty.")
+ return target_cols, source_cols
+
+
+def _resolve_num_partitions(num_partitions: Optional[int]) -> int:
+ if num_partitions is not None:
+ return num_partitions
+ try:
+ import ray
+
+ cpus = int(ray.cluster_resources().get("CPU", 4))
+ return max(1, cpus * 2)
+ except Exception:
+ return 4
+
+
def _require_ray_join() -> None:
    """Fail fast unless the installed Ray is new enough for ``Dataset.join``.

    Raises:
        RuntimeError: If ``ray.__version__`` is older than 2.50.0.
    """
    import ray
    from packaging.version import parse

    installed = ray.__version__
    if parse(installed) >= parse("2.50.0"):
        return
    raise RuntimeError(
        f"merge_into requires ray>=2.50; "
        f"installed ray is {installed}."
    )
+
+
+def _blob_col_names(table) -> set:
+ return {
+ f.name
+ for f in table.table_schema.fields
+ if getattr(f.type, "type", None) == "BLOB"
+ }
+
+
+def _reraise_inner(err: BaseException) -> None:
+ """Unwrap Ray's RayTaskError so callers see the worker-side exception."""
+ inner = err
+ cause = getattr(err, "cause", None) or getattr(err, "__cause__", None)
+ while cause is not None:
+ inner = cause
+ cause = getattr(inner, "cause", None) or getattr(inner, "__cause__",
None)
+ if inner is err:
+ raise err
+ raise inner from err
+
+
def _union_update_cols(clauses: List[_NormalizedClause]) -> List[str]:
    """Return every column set by any clause, deduplicated, first-seen order."""
    # dict keeps insertion order, so fromkeys dedupes while preserving it.
    ordered = dict.fromkeys(
        col for clause in clauses for col in clause.spec.keys()
    )
    return list(ordered)
+
+
def _needed_target_cols(
    clauses: List[_NormalizedClause],
    on: Sequence[str],
    update_cols: Sequence[str],
    all_target_cols: Sequence[str],
) -> list:
    """Return the target columns the matched join must read, in schema order.

    A target column is needed when it is a join key, is referenced as
    ``"t.<col>"`` by some clause's SET spec, or is an update column that at
    least one clause leaves unset (its old value must be carried over).
    Update columns that every clause sets are never read from the target.
    """
    update_set = set(update_cols)
    always_set = set(update_set)
    referenced = set()
    for clause in clauses:
        referenced.update(
            value[2:]
            for value in clause.spec.values()
            if isinstance(value, str) and value.startswith("t.")
        )
        always_set.intersection_update(clause.spec.keys())
    required = set(on) | referenced | (update_set - always_set)
    return [name for name in all_target_cols if name in required]
+
+
def _resolve_target_projection(
    clauses: List[_NormalizedClause],
    target_on: Sequence[str],
    update_cols: Sequence[str],
    target_field_names: Sequence[str],
) -> list:
    """Return the target columns to project for the matched join.

    Delegates to ``_needed_target_cols`` and keeps the result in
    ``target_field_names`` order.
    """
    wanted = frozenset(
        _needed_target_cols(clauses, target_on, update_cols, target_field_names)
    )
    return [name for name in target_field_names if name in wanted]
+
+
def _normalize_set_spec(
    spec: SetSpec,
    target_field_names: Sequence[str],
    on_map: Optional[Mapping[str, str]] = None,
) -> Dict[str, Any]:
    """Expand a SET spec into ``{target_col: "s.<source_col>"}``.

    Only ``"*"`` is supported: each column in ``target_field_names`` is set
    from the source column of the same name, except renamed ON keys, which
    resolve to their source-side name through ``on_map``.

    Raises:
        NotImplementedError: For any spec other than ``"*"``.
    """
    if spec == "*":
        renames = on_map or {}
        return {
            name: f"s.{renames.get(name, name)}"
            for name in target_field_names
        }
    raise NotImplementedError(
        "merge_into currently only supports '*' for update/insert; "
        "partial SET will be added in a follow-up PR."
    )
+
+
def _normalize_source(source: Any, catalog_options: Dict[str, str]):
    """Convert a supported merge source into a ``ray.data.Dataset``.

    Datasets pass through; a string is read as a Paimon table from the same
    catalog; ``pyarrow.Table`` and ``pandas.DataFrame`` are wrapped.

    Raises:
        TypeError: For any other source type.
    """
    import ray.data

    if isinstance(source, ray.data.Dataset):
        return source
    if isinstance(source, str):
        from pypaimon.ray.ray_paimon import read_paimon
        return read_paimon(source, catalog_options)
    if isinstance(source, pa.Table):
        return ray.data.from_arrow(source)

    # pandas is optional; without it a DataFrame source cannot exist.
    try:
        import pandas as pd
    except ImportError:
        is_frame = False
    else:
        is_frame = isinstance(source, pd.DataFrame)
    if is_frame:
        return ray.data.from_pandas(source)

    raise TypeError(
        "source must be a ray.data.Dataset, a Paimon table identifier string, "
        f"a pyarrow.Table, or a pandas.DataFrame; got {type(source).__name__}."
    )
+
+
+def _source_schema_or_raise(source_ds):
+ """Get source schema; refuse to proceed if Ray can't tell us the
columns."""
+ schema = source_ds.schema()
+ if schema is None:
+ raise ValueError(
+ "merge_into could not infer the source schema; pass a "
+ "ray.data.Dataset that has been materialized (e.g. via "
+ ".materialize()) or constructed from pyarrow/pandas."
+ )
+ return schema
+
+
def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:
    """Raise ``ValueError`` if any source-side join key is absent."""
    available = set(_source_schema_or_raise(source_ds).names)
    absent = [key for key in on if key not in available]
    if not absent:
        return
    raise ValueError(
        f"'on' columns {absent} missing from source schema {list(available)}."
    )
+
+
def _validate_source_has_target_cols(
    source_ds,
    target_field_names: Sequence[str],
    on_map: Mapping[str, str],
) -> None:
    """Require the source to carry every (non-blob) target column.

    With update='*' / insert='*' each target column is read from the source
    (renamed ON keys via ``on_map``); a missing source column would silently
    resolve to null and overwrite data, so it is rejected up front.
    """
    available = set(_source_schema_or_raise(source_ds).names)
    required = set()
    for name in target_field_names:
        required.add(on_map.get(name, name))
    absent = sorted(required.difference(available))
    if absent:
        raise ValueError(
            f"source is missing target columns {absent}; "
            f"update='*'/insert='*' requires the source to carry every "
            f"(non-blob) target column."
        )
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
new file mode 100644
index 0000000000..40e7994b2d
--- /dev/null
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -0,0 +1,353 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+import pyarrow as pa
+
+from pypaimon.ray.data_evolution_merge_transform import (
+ _NormalizedClause,
+ build_update_schema,
+ vectorized_insert_transform,
+ vectorized_matched_transform,
+)
+
+
+def _map_kwargs(
+ ray_remote_args: Optional[Dict[str, Any]],
+) -> Dict[str, Any]:
+ """Build kwargs for map_batches/map_groups; spread ray_remote_args because
+ those APIs take remote options as **kwargs, not under a 'ray_remote_args'
+ key."""
+ kwargs: Dict[str, Any] = {"batch_format": "pyarrow"}
+ if ray_remote_args:
+ kwargs.update(ray_remote_args)
+ return kwargs
+
+
def build_matched_update_ds(
    *,
    target_identifier: str,
    source_ds,
    target_on: Sequence[str],
    source_on: Sequence[str],
    clauses: List[_NormalizedClause],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
    update_cols: Sequence[str],
    catalog_options: Dict[str, str],
    num_partitions: int,
    resolve_target_projection,
    snapshot_id: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
):
    """Build the lazy dataset of matched rows to update.

    Reads ``_ROW_ID`` plus the columns chosen by
    ``resolve_target_projection`` from the target at ``snapshot_id``. Target
    columns are prefixed ``t.`` and source columns ``s.``; the two sides are
    inner-joined on the key columns. Each joined batch is then mapped by
    ``vectorized_matched_transform`` onto the schema from
    ``build_update_schema``.

    Raises:
        ValueError: If ``clauses`` does not hold exactly one clause.
    """
    # Only a single matched clause is supported for now. Check explicitly
    # (an ``assert`` disappears under ``python -O``) and before building any
    # Ray plan, so extra clauses can never be silently ignored.
    if len(clauses) != 1:
        raise ValueError(
            f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
        )

    from pypaimon.ray.ray_paimon import read_paimon
    from pypaimon.table.special_fields import SpecialFields

    row_id_name = SpecialFields.ROW_ID.name
    needed_cols = resolve_target_projection(
        clauses, target_on, update_cols, target_field_names,
    )
    projection = [row_id_name] + [c for c in needed_cols if c != row_id_name]

    target_ds = read_paimon(
        target_identifier, catalog_options,
        projection=projection, snapshot_id=snapshot_id,
    )
    update_schema = build_update_schema(
        target_pa_schema, update_cols, row_id_name
    )

    # Prefix both sides so equally named columns cannot collide in the join.
    target_renamed = target_ds.rename_columns(
        {c: f"t.{c}" for c in target_ds.schema().names}
    )
    source_renamed = source_ds.rename_columns(
        {c: f"s.{c}" for c in source_ds.schema().names}
    )

    joined = target_renamed.join(
        source_renamed,
        join_type="inner",
        num_partitions=num_partitions,
        on=tuple(f"t.{c}" for c in target_on),
        right_on=tuple(f"s.{c}" for c in source_on),
    )

    spec = clauses[0].spec
    on_pairs = list(zip(source_on, target_on))
    captured_update_cols = list(update_cols)

    def _transform(batch: pa.Table) -> pa.Table:
        return vectorized_matched_transform(
            batch, spec, on_pairs,
            captured_update_cols, row_id_name,
            update_schema,
        )

    return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))
+
+
def distributed_update_apply(
    update_ds,
    table,
    write_update_cols: Sequence[str],
    *,
    num_partitions: int,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    base_snapshot_id: Optional[int] = None,
) -> Tuple[list, int]:
    """Apply matched-row updates in Ray tasks and collect commit messages.

    Each row of ``update_ds`` (``_ROW_ID`` plus ``write_update_cols``) is
    tagged with a ``_FIRST_ROW_ID`` taken from the planner's first-row-ids.
    Rows are grouped by that tag, and each group is written by a
    ``TableUpdateByRowId`` worker via ``update_columns``. Nothing is
    committed here; the caller commits the returned messages.

    Args:
        update_ds: Dataset of rows to update (``_ROW_ID`` + update columns).
        table: Target table.
        write_update_cols: Target columns to rewrite.
        num_partitions: Upper bound on the number of groupby partitions.
        ray_remote_args: Ray remote options for the map/group tasks.
        base_snapshot_id: Snapshot the join was built on, used as the base
            for commit-time conflict detection; defaults to the planner's
            snapshot.

    Returns:
        ``(commit_messages, num_updated_rows)``; ``([], 0)`` when the planner
        reports no first-row-ids.

    Raises:
        ValueError: If a column is not in the target schema. Also raised
            inside Ray tasks (Ray may wrap it) when a ``_ROW_ID`` is null or
            out of range, or several rows share the same ``_ROW_ID``.
    """
    import numpy as np
    import pickle
    import uuid

    import pyarrow.compute as pc
    import ray

    from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
    from pypaimon.table.special_fields import SpecialFields
    from pypaimon.write.table_update_by_row_id import TableUpdateByRowId

    row_id_name = SpecialFields.ROW_ID.name
    cols = list(write_update_cols)

    # Reject unknown columns on the driver before launching any Ray work.
    for col in cols:
        if col not in table.field_names:
            raise ValueError(
                f"Column '{col}' is not in target table schema."
            )

    # Driver-side planner: supplies the first-row-ids, total row count,
    # snapshot id and file metadata used below.
    planner = TableUpdateByRowId(
        table,
        "_merge_into_planner_" + uuid.uuid4().hex[:8],
        BATCH_COMMIT_IDENTIFIER,
    )
    # Presumably ascending (searchsorted below relies on it) — TODO confirm
    # TableUpdateByRowId guarantees the order.
    sorted_first_row_ids = list(planner.first_row_ids)
    if not sorted_first_row_ids:
        return [], 0

    # Pin commit-time conflict check to the snapshot the join was built on,
    # so concurrent commits between read and planner are detected.
    check_from_snapshot = (
        base_snapshot_id if base_snapshot_id is not None
        else planner.snapshot_id
    )

    # Put the file metadata into Ray's object store once and pass workers a
    # single ref. This avoids a manifest re-scan in every task and avoids
    # serializing the metadata into every task's closure. snapshot_id is
    # overridden with the join's base snapshot so commit-time conflict
    # detection also covers the read -> planner window.
    from dataclasses import replace
    files_info = replace(
        planner._snapshot_files_info(),
        snapshot_id=check_from_snapshot,
    )
    precomputed_info_ref = ray.put(files_info)

    frid_col = "_FIRST_ROW_ID"
    captured_sorted = sorted_first_row_ids
    captured_sorted_arr = np.asarray(captured_sorted, dtype=np.int64)
    first = captured_sorted_arr[0]
    total_row_count = planner.total_row_count

    def _assign_frid(batch: pa.Table) -> pa.Table:
        # Appends _FIRST_ROW_ID: the largest first-row-id <= each _ROW_ID,
        # presumably identifying the data file that holds the row.
        if batch.num_rows == 0:
            return batch.append_column(
                frid_col, pa.array([], type=pa.int64())
            )
        rid_col = batch.column(row_id_name)
        if rid_col.null_count:
            raise ValueError(
                "_ROW_ID is null; planner snapshot is stale "
                "or matched rows come from a different table."
            )
        rids = rid_col.to_numpy(zero_copy_only=False)
        # Out-of-range _ROW_IDs would silently map via searchsorted wrap-around.
        out_of_range = (rids < first) | (rids >= total_row_count)
        if out_of_range.any():
            bad = rids[out_of_range][0]
            raise ValueError(
                f"_ROW_ID {bad} is out of valid range "
                f"[{first}, {total_row_count}); planner snapshot "
                f"is stale or matched rows come from a different "
                f"table."
            )
        # side="right" minus 1 picks the last first-row-id <= rid.
        idx = np.searchsorted(
            captured_sorted_arr, rids, side="right"
        ) - 1
        frids = captured_sorted_arr[idx]
        return batch.append_column(
            frid_col, pa.array(frids, type=pa.int64())
        )

    map_kwargs = _map_kwargs(ray_remote_args)
    with_frid = update_ds.map_batches(_assign_frid, **map_kwargs)

    captured_table = table
    captured_cols = cols

    def _apply_group(group: pa.Table) -> pa.Table:
        # Runs for all rows sharing one _FIRST_ROW_ID; returns a one-row
        # table holding the pickled commit messages and the row count.
        if group.num_rows == 0:
            return pa.Table.from_pydict({
                "msgs_blob": pa.array([], type=pa.binary()),
                "n_updated": pa.array([], type=pa.int64()),
            })

        # Two rows with the same _ROW_ID mean several source rows matched
        # one target row, which makes the update ambiguous.
        if (
            pc.count_distinct(group.column(row_id_name)).as_py()
            != group.num_rows
        ):
            raise ValueError(
                "MERGE matched multiple source rows to the same "
                "target _ROW_ID. Deduplicate the source before "
                "merging."
            )

        for_update = group.drop_columns([frid_col])
        worker = TableUpdateByRowId(
            captured_table,
            "_merge_into_shard_" + uuid.uuid4().hex[:8],
            BATCH_COMMIT_IDENTIFIER,
            _precomputed_files_info=ray.get(precomputed_info_ref),
        )
        msgs = worker.update_columns(for_update, list(captured_cols))
        # Commit messages are pickled so they travel as a binary column.
        return pa.Table.from_pydict({
            "msgs_blob": [pickle.dumps(msgs)],
            "n_updated": pa.array(
                [for_update.num_rows], type=pa.int64()
            ),
        })

    # One group per target data file; bounded by file count and num_partitions.
    group_partitions = max(
        1, min(len(captured_sorted), num_partitions)
    )
    msgs_ds = with_frid.groupby(
        frid_col, num_partitions=group_partitions
    ).map_groups(_apply_group, **map_kwargs)

    # Driver side: iterating executes the pipeline; gather messages/counts.
    all_msgs: list = []
    num_updated = 0
    for batch in msgs_ds.iter_batches(batch_format="pyarrow"):
        # The blobs were produced by _apply_group above, not read from
        # external input, so unpickling them here is expected.
        for blob in batch.column("msgs_blob").to_pylist():
            all_msgs.extend(pickle.loads(blob))
        for n in batch.column("n_updated").to_pylist():
            num_updated += n
    return all_msgs, num_updated
+
+
def build_not_matched_insert_ds(
    *,
    target_identifier: str,
    source_ds,
    target_on: Sequence[str],
    source_on: Sequence[str],
    clauses: List[_NormalizedClause],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
    catalog_options: Dict[str, str],
    num_partitions: int,
    target_empty: bool = False,
    snapshot_id: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
):
    """Build the lazy dataset of source rows to insert.

    Source columns are prefixed ``s.``. Unless ``target_empty`` is set, the
    source is left-anti-joined with the target's key columns, which are read
    at ``snapshot_id`` and prefixed ``t.``, keeping only unmatched rows. Each
    batch is then shaped to ``target_field_names`` / ``target_pa_schema`` by
    ``vectorized_insert_transform`` and passed through
    ``_coerce_large_string_types``.

    Raises:
        ValueError: If ``clauses`` does not hold exactly one clause.
    """
    # Only a single not-matched clause is supported for now. Check
    # explicitly (an ``assert`` disappears under ``python -O``) and before
    # building any Ray plan, so extra clauses can never be silently ignored.
    if len(clauses) != 1:
        raise ValueError(
            f"build_not_matched_insert_ds expected 1 clause, "
            f"got {len(clauses)}"
        )

    from pypaimon.ray.ray_paimon import read_paimon
    from pypaimon.ray.shuffle import _coerce_large_string_types

    spec = clauses[0].spec
    captured_field_names = list(target_field_names)
    out_schema = target_pa_schema

    source_renamed = source_ds.rename_columns(
        {c: f"s.{c}" for c in source_ds.schema().names}
    )

    if target_empty:
        # Nothing to match against: every source row is an insert.
        unmatched = source_renamed
    else:
        target_ds = read_paimon(
            target_identifier, catalog_options,
            projection=list(target_on), snapshot_id=snapshot_id,
        )
        target_renamed = target_ds.rename_columns(
            {c: f"t.{c}" for c in target_on}
        )
        unmatched = source_renamed.join(
            target_renamed,
            join_type="left_anti",
            num_partitions=num_partitions,
            on=tuple(f"s.{c}" for c in source_on),
            right_on=tuple(f"t.{c}" for c in target_on),
        )

    def _transform(batch: pa.Table) -> pa.Table:
        return _coerce_large_string_types(
            vectorized_insert_transform(
                batch, spec, captured_field_names, out_schema
            )
        )

    return unmatched.map_batches(
        _transform, **_map_kwargs(ray_remote_args)
    )
+
+
def distributed_write_collect_msgs(
    insert_ds,
    table,
    *,
    ray_remote_args: Optional[Dict[str, Any]],
    concurrency: Optional[int],
) -> list:
    """Write ``insert_ds`` through a Paimon datasink and return its messages.

    A ``PaimonDatasink`` subclass overrides ``on_write_complete`` so that it
    only records the non-empty commit messages; the caller commits them.
    ``ray_remote_args`` / ``concurrency`` are forwarded to
    ``write_datasink`` only when they are not ``None``.
    """
    from pypaimon.write.ray_datasink import PaimonDatasink

    class _CollectingDatasink(PaimonDatasink):
        def __init__(self, t):
            super().__init__(t, overwrite=False)
            self.collected: list = []

        def on_write_complete(self, write_result):
            gathered = []
            for returned in self._extract_write_returns(write_result):
                gathered.extend(m for m in returned if not m.is_empty())
            self.collected = gathered

    sink = _CollectingDatasink(table)
    optional_kwargs = {
        "ray_remote_args": ray_remote_args,
        "concurrency": concurrency,
    }
    write_kwargs: Dict[str, Any] = {
        key: value
        for key, value in optional_kwargs.items()
        if value is not None
    }
    insert_ds.write_datasink(sink, **write_kwargs)
    return sink.collected
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
new file mode 100644
index 0000000000..0fc2d22f77
--- /dev/null
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
@@ -0,0 +1,121 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from dataclasses import dataclass
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
+
+import pyarrow as pa
+
+# Value spec of a merge clause: either a bare string such as '*' (the tests
+# use it to copy every source column), or a mapping from target column name
+# to a value (presumably the same "s.<col>" / "t.<col>" / literal forms that
+# _resolve_spec_array accepts -- confirm against the normalization code).
+SetSpec = Union[str, Mapping[str, Any]]
+# Join keys: a sequence of column names present in both target and source,
+# or a mapping {target_column: source_column} when the names differ (the
+# tests pass on={'id': 'uid'} for target 'id' / source 'uid').
+OnSpec = Union[Sequence[str], Mapping[str, str]]
+
+
+@dataclass
+class WhenMatched:
+ """WHEN MATCHED clause: rewrite matched target rows according to ``update``."""
+ update: SetSpec
+ # Optional condition string. How it is evaluated is not visible in this
+ # module -- NOTE(review): confirm whether conditions are supported yet.
+ condition: Optional[str] = None
+
+
+@dataclass
+class WhenNotMatched:
+ """WHEN NOT MATCHED clause: insert unmatched source rows per ``insert``."""
+ insert: SetSpec
+ # Optional condition string; see the note on WhenMatched.condition.
+ condition: Optional[str] = None
+
+
+@dataclass
+class _NormalizedClause:
+ """Internal clause form consumed by the vectorized transforms."""
+ # Target column name -> value spec; columns absent from this dict are
+ # kept (matched update) or filled with nulls (insert).
+ spec: Dict[str, Any]
+
+
def vectorized_matched_transform(
    batch: pa.Table,
    spec: Dict[str, Any],
    on_pairs: Sequence[Tuple[str, str]],
    update_cols: Sequence[str],
    row_id_name: str,
    update_schema: pa.Schema,
) -> pa.Table:
    """Build the partial-update table for one batch of matched rows.

    ``batch`` is a joined batch whose columns carry ``s.`` (source) and
    ``t.`` (target) prefixes. The output has ``update_schema``: the target
    row-id column first, then one column per entry of ``update_cols``.
    Columns named in ``spec`` are resolved via ``_resolve_spec_array``;
    all other update columns keep the current target value ``t.<col>``.
    """
    present = set(batch.schema.names)

    def _column_for(name: str):
        target_type = update_schema.field(name).type
        if name not in spec:
            # Not assigned by the clause: carry the target value through.
            return batch.column(f"t.{name}")
        return _resolve_spec_array(
            spec[name], batch, present, on_pairs, target_type
        )

    columns: list = [batch.column(f"t.{row_id_name}")]
    columns.extend(_column_for(name) for name in update_cols)
    return pa.Table.from_arrays(columns, schema=update_schema)
+
+
def vectorized_insert_transform(
    batch: pa.Table,
    spec: Dict[str, Any],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
) -> pa.Table:
    """Build the rows to insert for one batch of unmatched source rows.

    One output column is produced per name in ``target_field_names``, typed
    by ``target_pa_schema``. Columns named in ``spec`` are resolved via
    ``_resolve_spec_array`` (no join-key fallback, hence empty ``on_pairs``);
    every other column is filled with nulls.
    """
    present = set(batch.schema.names)
    columns: list = []
    for name in target_field_names:
        col_type = target_pa_schema.field(name).type
        columns.append(
            _resolve_spec_array(spec[name], batch, present, (), col_type)
            if name in spec
            else pa.nulls(batch.num_rows, type=col_type)
        )
    return pa.Table.from_arrays(columns, schema=target_pa_schema)
+
+
def build_update_schema(
    target_pa_schema: pa.Schema,
    update_cols: Sequence[str],
    row_id_name: str,
) -> pa.Schema:
    """Return the schema of a partial-update table.

    The first field is the non-nullable ``int64`` row-id column named
    ``row_id_name``; it is followed by the target's own fields for
    ``update_cols``, in that order.
    """
    row_id_field = pa.field(row_id_name, pa.int64(), nullable=False)
    fields = [row_id_field]
    fields.extend(target_pa_schema.field(name) for name in update_cols)
    return pa.schema(fields)
+
+
def _resolve_spec_array(
    val: Any,
    batch: pa.Table,
    available: set,
    on_pairs: Sequence[Tuple[str, str]],
    out_type: pa.DataType,
):
    """Materialize one value of a clause spec as a column of ``batch``.

    ``val`` is interpreted as follows:

    * ``"s.<col>"``: the source column ``s.<col>``. If the batch lacks it but
      ``<col>`` is a source-side join key in ``on_pairs``, the paired target
      key column ``t.<key>`` is returned instead (it equals the source key on
      rows matched by the ON join). Otherwise a column of nulls.
    * ``"t.<col>"``: the target column ``t.<col>``, or nulls when the batch
      does not carry it (e.g. left-anti-joined unmatched source rows).
    * anything else: a literal repeated for every row of the batch.

    Args:
        val: Reference string or literal from the clause spec.
        batch: Batch whose columns are prefixed with ``s.`` / ``t.``.
        available: Names of the columns present in ``batch``.
        on_pairs: ``(source_key, target_key)`` join-key pairs.
        out_type: Arrow type used for null and literal columns.

    Returns:
        An Arrow array or chunked array with ``batch.num_rows`` rows.
    """
    num_rows = batch.num_rows
    if isinstance(val, str) and val.startswith("s."):
        ref = val[2:]
        if f"s.{ref}" in available:
            return batch.column(f"s.{ref}")
        for sk, tk in on_pairs:
            if sk == ref and f"t.{tk}" in available:
                return batch.column(f"t.{tk}")
        return pa.nulls(num_rows, type=out_type)
    if isinstance(val, str) and val.startswith("t."):
        # "t.<col>" already is the batch column name.
        if val in available:
            return batch.column(val)
        return pa.nulls(num_rows, type=out_type)
    # Broadcast the literal in native code instead of materializing a
    # Python list of num_rows copies and converting it element by element.
    return pa.repeat(pa.scalar(val, type=out_type), num_rows)
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
new file mode 100644
index 0000000000..6b918264bd
--- /dev/null
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -0,0 +1,470 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+import uuid
+
+import pyarrow as pa
+import ray
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into
+
+_TEST_NUM_PARTITIONS = 2
+
+
class RayDataEvolutionMergeIntoTest(unittest.TestCase):
    """End-to-end tests for ``pypaimon.ray.merge_into``.

    Every test creates a fresh table in a shared temporary warehouse. Ray is
    started by ``setUpClass`` unless it is already running.
    """

    pa_schema = pa.schema([
        ('id', pa.int32()),
        ('name', pa.string()),
        ('age', pa.int32()),
    ])

    # Options of a data-evolution table; merge_into rejects tables missing
    # either flag (see the *_rejected tests below).
    de_options = {
        'row-tracking.enabled': 'true',
        'data-evolution.enabled': 'true',
    }

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog_options = {'warehouse': cls.warehouse}
        cls.catalog = CatalogFactory.create(cls.catalog_options)
        cls.catalog.create_database('default', True)
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True, num_cpus=2)

    @classmethod
    def tearDownClass(cls):
        # Best effort: a failing Ray shutdown must not skip temp-dir cleanup.
        try:
            if ray.is_initialized():
                ray.shutdown()
        except Exception:
            pass
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_table(self, options=None):
        """Create a uniquely named table with ``pa_schema``; return its name."""
        opts = options if options is not None else self.de_options
        name = f'default.tbl_{uuid.uuid4().hex[:8]}'
        s = Schema.from_pyarrow_schema(self.pa_schema, options=opts)
        self.catalog.create_table(name, s, False)
        return name

    def _rows(self, ids, names, ages):
        """Build an Arrow table in ``pa_schema`` from parallel column lists."""
        return pa.Table.from_pydict(
            {
                'id': pa.array(list(ids), type=pa.int32()),
                'name': list(names),
                'age': pa.array(list(ages), type=pa.int32()),
            },
            schema=self.pa_schema,
        )

    def _source(self, ids=(1,)):
        """Return placeholder source rows (name 'x', age 10) for ``ids``."""
        return self._rows(ids, ['x'] * len(ids), [10] * len(ids))

    def _write(self, target, data):
        """Append ``data`` to ``target`` in one batch commit."""
        table = self.catalog.get_table(target)
        wb = table.new_batch_write_builder()
        writer = wb.new_write()
        try:
            writer.write_arrow(data)
            wb.new_commit().commit(writer.prepare_commit())
        finally:
            # Close the writer even if the write or commit fails.
            writer.close()

    def _read_sorted(self, target):
        """Read all of ``target`` and return it as a dict sorted by id."""
        table = self.catalog.get_table(target)
        rb = table.new_read_builder()
        splits = rb.new_scan().plan().splits()
        return rb.new_read().to_arrow(splits).sort_by('id').to_pydict()

    def _snapshot_id(self, target):
        """Return the latest snapshot id of ``target``, or None if none."""
        table = self.catalog.get_table(target)
        snap = table.snapshot_manager().get_latest_snapshot()
        return snap.id if snap is not None else None

    def test_no_clause_raises(self):
        target = self._create_table()
        with self.assertRaises(ValueError):
            merge_into(
                target=target,
                source=self._source(),
                catalog_options=self.catalog_options,
                on=['id'],
                num_partitions=_TEST_NUM_PARTITIONS,
            )

    def test_non_de_table_rejected(self):
        target = self._create_table(options={'row-tracking.enabled': 'true'})
        with self.assertRaises(ValueError) as ctx:
            merge_into(
                target=target,
                source=self._source(),
                catalog_options=self.catalog_options,
                on=['id'],
                when_matched=[WhenMatched(update='*')],
                num_partitions=_TEST_NUM_PARTITIONS,
            )
        self.assertIn('data-evolution.enabled', str(ctx.exception))

    def test_no_row_tracking_rejected(self):
        target = self._create_table(options={'data-evolution.enabled': 'true'})
        with self.assertRaises(ValueError) as ctx:
            merge_into(
                target=target,
                source=self._source(),
                catalog_options=self.catalog_options,
                on=['id'],
                when_matched=[WhenMatched(update='*')],
                num_partitions=_TEST_NUM_PARTITIONS,
            )
        self.assertIn('row-tracking.enabled', str(ctx.exception))

    def test_source_missing_on_col_raises(self):
        target = self._create_table()
        bad_source = pa.Table.from_pydict(
            {'name': ['x'], 'age': [10]},
            schema=pa.schema([('name', pa.string()), ('age', pa.int32())]),
        )
        with self.assertRaises(ValueError) as ctx:
            merge_into(
                target=target,
                source=bad_source,
                catalog_options=self.catalog_options,
                on=['id'],
                when_matched=[WhenMatched(update='*')],
                num_partitions=_TEST_NUM_PARTITIONS,
            )
        self.assertIn("'id'", str(ctx.exception))

    def test_matched_update_star(self):
        target = self._create_table()
        self._write(target, self._rows([1, 2, 3], ['a', 'b', 'c'], [10, 20, 30]))
        source = self._rows([2, 3, 4], ['b2', 'c2', 'd'], [22, 33, 40])

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_matched=[WhenMatched(update='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        out = self._read_sorted(target)
        self.assertEqual(out['id'], [1, 2, 3])
        self.assertEqual(out['name'], ['a', 'b2', 'c2'])
        self.assertEqual(out['age'], [10, 22, 33])

    def test_not_matched_insert_appends_unmatched(self):
        target = self._create_table()
        self._write(target, self._rows([1, 2, 3], ['a', 'b', 'c'], [10, 20, 30]))
        source = self._rows([2, 3, 4], ['b2', 'c2', 'd'], [22, 33, 40])

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_not_matched=[WhenNotMatched(insert='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        out = self._read_sorted(target)
        self.assertEqual(out['id'], [1, 2, 3, 4])
        self.assertEqual(out['name'], ['a', 'b', 'c', 'd'])
        self.assertEqual(out['age'], [10, 20, 30, 40])

    def test_combined_update_and_insert(self):
        target = self._create_table()
        self._write(target, self._rows([1, 2], ['a', 'b'], [10, 20]))
        source = self._rows([2, 3], ['b2', 'c'], [22, 30])

        metrics = merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_matched=[WhenMatched(update='*')],
            when_not_matched=[WhenNotMatched(insert='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        out = self._read_sorted(target)
        self.assertEqual(out['id'], [1, 2, 3])
        self.assertEqual(out['name'], ['a', 'b2', 'c'])
        self.assertEqual(out['age'], [10, 22, 30])
        self.assertEqual(metrics, {
            'num_matched': 1, 'num_inserted': 1, 'num_unchanged': 0,
        })

    def test_on_with_renamed_columns_star(self):
        target = self._create_table()
        self._write(target, self._rows([1, 2], ['a', 'b'], [10, 20]))

        # Source names its key 'uid'; on={'id': 'uid'} maps target -> source.
        source_schema = pa.schema([
            ('uid', pa.int32()),
            ('name', pa.string()),
            ('age', pa.int32()),
        ])
        source = pa.Table.from_pydict(
            {
                'uid': pa.array([2, 3], type=pa.int32()),
                'name': ['b2', 'c'],
                'age': pa.array([22, 30], type=pa.int32()),
            },
            schema=source_schema,
        )

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on={'id': 'uid'},
            when_matched=[WhenMatched(update='*')],
            when_not_matched=[WhenNotMatched(insert='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        out = self._read_sorted(target)
        self.assertEqual(out['id'], [1, 2, 3])
        self.assertEqual(out['name'], ['a', 'b2', 'c'])
        self.assertEqual(out['age'], [10, 22, 30])

    def test_insert_into_empty_target(self):
        target = self._create_table()
        source = self._rows([1, 2, 3], ['a', 'b', 'c'], [10, 20, 30])

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_not_matched=[WhenNotMatched(insert='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        out = self._read_sorted(target)
        self.assertEqual(out['id'], [1, 2, 3])
        self.assertEqual(out['name'], ['a', 'b', 'c'])
        self.assertEqual(out['age'], [10, 20, 30])

    def test_multi_source_match_raises_by_default(self):
        # One target row matched by several source rows: the winning value is
        # undefined (Spark DE's checkCardinality=false), so we refuse by
        # default.
        target = self._create_table()
        self._write(target, self._rows([1], ['a'], [10]))
        source = self._rows([1, 1], ['x', 'y'], [100, 200])

        with self.assertRaises(Exception) as ctx:
            merge_into(
                target=target,
                source=source,
                catalog_options=self.catalog_options,
                on=['id'],
                when_matched=[WhenMatched(update='*')],
                num_partitions=_TEST_NUM_PARTITIONS,
            )
        self.assertIn("multiple source rows", str(ctx.exception))

    def test_blob_columns_excluded(self):
        import types

        from pypaimon.ray.data_evolution_merge_into import _blob_col_names
        from pypaimon.schema.data_types import AtomicType, DataField

        fake_table = types.SimpleNamespace(
            table_schema=types.SimpleNamespace(
                fields=[
                    DataField(0, 'id', AtomicType('INT')),
                    DataField(1, 'payload', AtomicType('BLOB')),
                ]
            )
        )
        self.assertEqual({'payload'}, _blob_col_names(fake_table))

    def test_combined_writes_single_snapshot(self):
        target = self._create_table()
        self._write(target, self._rows([1, 2], ['a', 'b'], [10, 20]))
        before = self._snapshot_id(target)
        source = self._rows([2, 3], ['b2', 'c'], [22, 30])

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_matched=[WhenMatched(update='*')],
            when_not_matched=[WhenNotMatched(insert='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        # Update and insert must land in exactly one new snapshot.
        after = self._snapshot_id(target)
        self.assertEqual(after, before + 1)

    def test_empty_target_matched_update_is_noop(self):
        target = self._create_table()
        before = self._snapshot_id(target)
        source = self._rows([1, 2], ['a', 'b'], [10, 20])

        merge_into(
            target=target,
            source=source,
            catalog_options=self.catalog_options,
            on=['id'],
            when_matched=[WhenMatched(update='*')],
            num_partitions=_TEST_NUM_PARTITIONS,
        )

        self.assertEqual(self._snapshot_id(target), before)
+
+
class TargetProjectionTest(unittest.TestCase):
    """Unit tests for ``_resolve_target_projection`` in merge_into."""

    def _clause(self, spec):
        """Wrap ``spec`` in the internal normalized-clause type."""
        from pypaimon.ray import data_evolution_merge_into as merge_mod
        return merge_mod._NormalizedClause(spec=spec)

    def test_unconditional_set_excludes_target_update_col(self):
        # A column assigned from the source by an unconditional clause does
        # not need to be read from the target.
        from pypaimon.ray import data_evolution_merge_into as merge_mod
        clauses = [self._clause({'feature': 's.feature'})]
        projected = merge_mod._resolve_target_projection(
            clauses,
            ['id'],
            ['feature'],
            ['id', 'feature', 'image'],
        )
        self.assertEqual(['id'], projected)


if __name__ == '__main__':
    unittest.main()
diff --git a/paimon-python/pypaimon/write/ray_datasink.py
b/paimon-python/pypaimon/write/ray_datasink.py
index 989ec42dc4..60edbc855e 100644
--- a/paimon-python/pypaimon/write/ray_datasink.py
+++ b/paimon-python/pypaimon/write/ray_datasink.py
@@ -135,22 +135,26 @@ class PaimonDatasink(_DatasinkBase):
return commit_messages_list
+ @staticmethod
+ def _extract_write_returns(write_result: Any):
+ """Normalize WriteResult.write_returns (Ray 2.44+) vs list of returns
+ (older Ray) into a list of per-task commit-message lists."""
+ if hasattr(write_result, "write_returns"):
+ return write_result.write_returns
+ if isinstance(write_result, list):
+ return write_result
+ raise TypeError(
+ f"Unexpected write_result type {type(write_result).__name__}: "
+ "expected object with .write_returns or list of commit message "
+ "lists. Refusing to proceed to avoid silent data loss."
+ )
+
def on_write_complete(
self, write_result: Any
):
table_commit = None
try:
- # WriteResult.write_returns (Ray 2.44+); older Ray may pass list
of returns
- if hasattr(write_result, "write_returns"):
- write_returns = write_result.write_returns
- elif isinstance(write_result, list):
- write_returns = write_result
- else:
- raise TypeError(
- f"Unexpected write_result type
{type(write_result).__name__}: "
- "expected object with .write_returns or list of commit
message lists. "
- "Refusing to proceed to avoid silent data loss."
- )
+ write_returns = self._extract_write_returns(write_result)
all_commit_messages = [
commit_message
for commit_messages in write_returns
diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py
b/paimon-python/pypaimon/write/table_update_by_row_id.py
index ac9c68c362..e4c448be14 100644
--- a/paimon-python/pypaimon/write/table_update_by_row_id.py
+++ b/paimon-python/pypaimon/write/table_update_by_row_id.py
@@ -16,6 +16,7 @@
# under the License.
import bisect
+from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import pyarrow as pa
@@ -32,6 +33,21 @@ from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.file_store_write import FileStoreWrite
+@dataclass(frozen=True)
+class _FilesInfo:
+ """Snapshot view of target data files keyed by first_row_id.
+
+ Built once per merge by the driver and broadcast to workers so each task
+ avoids re-scanning the manifest.
+
+ ``frozen`` only prevents rebinding the attributes; the list and dict
+ values remain mutable and are shared with whoever built the instance.
+ """
+ # Snapshot the index was built from; _load_existing_files_info stores -1
+ # when the scan plan carries no snapshot id.
+ snapshot_id: int
+ # Sorted keys of first_row_id_index.
+ first_row_ids: List[int]
+ # first_row_id -> (owning split, files sharing that first_row_id).
+ first_row_id_index: Dict[int, Tuple[DataSplit, List[DataFileMeta]]] = (
+ field(default_factory=dict)
+ )
+ # Presumably the total row count of the indexed files; it is computed in
+ # the part of _load_existing_files_info not shown here -- confirm.
+ total_row_count: int = 0
+
+
class TableUpdateByRowId:
"""
Table update for partial column updates (data evolution).
@@ -42,32 +58,40 @@ class TableUpdateByRowId:
FIRST_ROW_ID_COLUMN = '_FIRST_ROW_ID'
- def __init__(self, table, commit_user: str, commit_identifier: int):
+ def __init__(
+ self, table, commit_user: str, commit_identifier: int,
+ _precomputed_files_info: Optional[_FilesInfo] = None,
+ ):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.commit_user = commit_user
self.commit_identifier = commit_identifier
- # Snapshot the current state once: a single ``first_row_id -> (split,
files)``
- # map is enough to drive every downstream lookup (partition,
row-count, read).
- (self.snapshot_id,
- self.first_row_ids,
- self._first_row_id_index,
- self.total_row_count) = self._load_existing_files_info()
+ info = _precomputed_files_info or self._load_existing_files_info()
+ self.snapshot_id = info.snapshot_id
+ self.first_row_ids = info.first_row_ids
+ self._first_row_id_index = info.first_row_id_index
+ self.total_row_count = info.total_row_count
self.commit_messages: List[CommitMessage] = []
- def _load_existing_files_info(
- self,
- ) -> Tuple[int, List[int], Dict[int, Tuple[DataSplit,
List[DataFileMeta]]], int]:
+ def _snapshot_files_info(self) -> _FilesInfo:
+ """Internal: return the current snapshot's file index for broadcast."""
+ return _FilesInfo(
+ snapshot_id=self.snapshot_id,
+ first_row_ids=self.first_row_ids,
+ first_row_id_index=self._first_row_id_index,
+ total_row_count=self.total_row_count,
+ )
+
+ def _load_existing_files_info(self) -> _FilesInfo:
"""Scan the latest snapshot once and index files by ``first_row_id``.
- Returns:
- A 4-tuple of ``(snapshot_id, sorted_unique_first_row_ids, index,
total_row_count)``
- where ``index`` maps each ``first_row_id`` to the owning split and
- the list of files with that id (a single id may belong to multiple
- files when data evolution has split a logical row range).
+ Returns a :class:`_FilesInfo` whose ``first_row_id_index`` maps each
+ ``first_row_id`` to the owning split and the list of files with that
+ id (a single id may belong to multiple files when data evolution has
+ split a logical row range).
"""
plan = self.table.new_read_builder().new_scan().plan()
splits = plan.splits()
@@ -95,7 +119,12 @@ class TableUpdateByRowId:
total_row_count = 0
snapshot_id = plan.snapshot_id if plan.snapshot_id is not None else -1
- return snapshot_id, sorted(index.keys()), index, total_row_count
+ return _FilesInfo(
+ snapshot_id=snapshot_id,
+ first_row_ids=sorted(index.keys()),
+ first_row_id_index=index,
+ total_row_count=total_row_count,
+ )
def update_columns(self, data: pa.Table, column_names: List[str]) ->
List[CommitMessage]:
"""