This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 219cecd46e [python][ray] Support multi-clause fall-through in merge_into (#8115)
219cecd46e is described below

commit 219cecd46eb8b60ca8a8c1cf8cbb3c5d983d1fde
Author: XiaoHongbo <[email protected]>
AuthorDate: Sat Jun 6 16:43:07 2026 +0800

    [python][ray] Support multi-clause fall-through in merge_into (#8115)
---
 docs/docs/pypaimon/ray-data.md                     |   7 +
 .../pypaimon/ray/data_evolution_merge_into.py      |  15 +-
 .../pypaimon/ray/data_evolution_merge_join.py      | 137 +++++----
 .../tests/ray_data_evolution_merge_into_test.py    | 333 +++++++++++++++++++++
 4 files changed, 432 insertions(+), 60 deletions(-)

diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index d160b4302f..a6d98f8fcc 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -386,6 +386,13 @@ columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.
   Use `lit()` for literals starting with `s.` or `t.`.
 - `condition`: an optional SQL-style boolean expression. Use `s.<col>` and
   `t.<col>` to reference source and target columns.
+- Multiple clauses are evaluated in order; the first matching condition wins, and only the last clause may omit `condition`:
+  ```python
+  when_matched=[
+      WhenMatched(update="*", condition="s.ts > t.ts"),
+      WhenMatched(update="*"),  # fallback for matched rows no earlier clause handled
+  ]
+  ```
 
 **Parameters:**
 - `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
index fa824b44a2..cbfcef907d 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -93,12 +93,15 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
         raise ValueError(
             "At least one of when_matched or when_not_matched must be 
non-empty."
         )
-    if len(when_matched) > 1 or len(when_not_matched) > 1:
-        raise NotImplementedError(
-            "merge_into currently supports a single WhenMatched and a single "
-            "WhenNotMatched clause; multi-clause fall-through will be added "
-            "in a follow-up PR."
-        )
+    for label, clauses in [("when_matched", when_matched),
+                           ("when_not_matched", when_not_matched)]:
+        for i, clause in enumerate(clauses[:-1]):
+            if clause.condition is None:
+                raise ValueError(
+                    f"Only the last {label} clause may omit its condition. "
+                    f"Clause at index {i} has no condition, making subsequent "
+                    f"clauses unreachable."
+                )
     target_on_cols, source_on_cols = _normalize_on(on)
 
     from pypaimon.catalog.catalog_factory import CatalogFactory
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
index f01f9b59aa..14088979f8 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -87,43 +87,57 @@ def build_matched_update_ds(
         right_on=tuple(f"s.{c}" for c in source_on),
     )
 
-    # MVP supports a single matched clause; future fan-out (conditions, multi-
-    # clause fall-through) must thread every clause's spec through the
-    # transform — guard so silent first-only behaviour can't sneak in.
-    assert len(clauses) == 1, (
-        f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
-    )
-    spec = clauses[0].spec
-    condition = clauses[0].condition
     captured_update_cols = list(update_cols)
     captured_row_id_name = row_id_name
     captured_on_pairs = list(zip(source_on, target_on))
     captured_schema = update_schema
 
-    captured_apply = None
-    captured_rewritten = None
-    if condition is not None:
-        from pypaimon.ray.merge_condition import (
-            apply_condition, remap_source_on_keys, rewrite_condition,
-        )
-        on_map = dict(zip(source_on, target_on))
-        captured_rewritten = remap_source_on_keys(
-            rewrite_condition(condition), on_map,
-        )
-        captured_apply = apply_condition
+    on_map = dict(zip(source_on, target_on))
+    prepared_clauses = []
+    for clause in clauses:
+        rewritten = None
+        if clause.condition is not None:
+            from pypaimon.ray.merge_condition import (
+                remap_source_on_keys, rewrite_condition,
+            )
+            rewritten = remap_source_on_keys(
+                rewrite_condition(clause.condition), on_map,
+            )
+        prepared_clauses.append((clause.spec, rewritten))
+
+    _filter_batch = None
+    if any(r is not None for _, r in prepared_clauses):
+        from pypaimon.ray.merge_condition import filter_batch as _filter_batch
 
     def _transform(batch: pa.Table) -> pa.Table:
-        if captured_apply is not None:
-            batch = captured_apply(
-                batch, captured_rewritten, captured_schema,
-            )
-            if batch.num_rows == 0:
-                return batch
-        return vectorized_matched_transform(
-            batch, spec, captured_on_pairs,
-            captured_update_cols, captured_row_id_name,
-            captured_schema,
-        )
+        remaining = batch
+        parts = []
+        for spec, rewritten in prepared_clauses:
+            if remaining.num_rows == 0:
+                break
+            if rewritten is not None:
+                matched = _filter_batch(
+                    remaining, rewritten, _pre_rewritten=True,
+                )
+            else:
+                matched = remaining
+            if matched.num_rows == 0:
+                continue
+            parts.append(vectorized_matched_transform(
+                matched, spec, captured_on_pairs,
+                captured_update_cols, captured_row_id_name,
+                captured_schema,
+            ))
+            if rewritten is not None and matched.num_rows < remaining.num_rows:
+                not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
+                remaining = _filter_batch(
+                    remaining, not_cond, _pre_rewritten=True,
+                )
+            else:
+                remaining = remaining.slice(0, 0)
+        if not parts:
+            return captured_schema.empty_table()
+        return pa.concat_tables(parts)
 
     return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))
 
@@ -324,32 +338,47 @@ def build_not_matched_insert_ds(
             right_on=tuple(f"t.{c}" for c in target_on),
         )
 
-    # MVP supports a single not-matched clause; see build_matched_update_ds
-    # for why we assert instead of silently dropping the rest.
-    assert len(clauses) == 1, (
-        f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}"
-    )
-    spec = clauses[0].spec
-    condition = clauses[0].condition
-    captured_apply = None
-    captured_rewritten = None
-    if condition is not None:
-        from pypaimon.ray.merge_condition import apply_condition, 
rewrite_condition
-        captured_rewritten = rewrite_condition(condition)
-        captured_apply = apply_condition
+    prepared_clauses = []
+    for clause in clauses:
+        rewritten = None
+        if clause.condition is not None:
+            from pypaimon.ray.merge_condition import rewrite_condition
+            rewritten = rewrite_condition(clause.condition)
+        prepared_clauses.append((clause.spec, rewritten))
+
+    _filter_batch_nm = None
+    if any(r is not None for _, r in prepared_clauses):
+        from pypaimon.ray.merge_condition import filter_batch as 
_filter_batch_nm
 
     def _transform(batch: pa.Table) -> pa.Table:
-        if captured_apply is not None:
-            batch = captured_apply(
-                batch, captured_rewritten, out_schema,
-            )
-            if batch.num_rows == 0:
-                return _coerce_large_string_types(batch)
-        return _coerce_large_string_types(
-            vectorized_insert_transform(
-                batch, spec, captured_field_names, out_schema
-            )
-        )
+        remaining = batch
+        parts = []
+        for spec, rewritten in prepared_clauses:
+            if remaining.num_rows == 0:
+                break
+            if rewritten is not None:
+                matched = _filter_batch_nm(
+                    remaining, rewritten, _pre_rewritten=True,
+                )
+                if matched.num_rows > 0:
+                    parts.append(vectorized_insert_transform(
+                        matched, spec, captured_field_names, out_schema
+                    ))
+                if matched.num_rows < remaining.num_rows:
+                    not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
+                    remaining = _filter_batch_nm(
+                        remaining, not_cond, _pre_rewritten=True,
+                    )
+                else:
+                    remaining = remaining.slice(0, 0)
+            else:
+                parts.append(vectorized_insert_transform(
+                    remaining, spec, captured_field_names, out_schema
+                ))
+                remaining = remaining.slice(0, 0)
+        if not parts:
+            return _coerce_large_string_types(out_schema.empty_table())
+        return _coerce_large_string_types(pa.concat_tables(parts))
 
     return unmatched.map_batches(
         _transform, **_map_kwargs(ray_remote_args)
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
index ca06d43e53..b54eeb5cf0 100644
--- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -153,6 +153,40 @@ class RayDataEvolutionMergeIntoTest(unittest.TestCase):
                 num_partitions=_TEST_NUM_PARTITIONS,
             )
 
+    def test_unconditional_non_last_matched_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[
+                    WhenMatched(update='*'),
+                    WhenMatched(update={'age': 's.age'}, condition='s.age > 
10'),
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('when_matched', str(ctx.exception))
+        self.assertIn('unreachable', str(ctx.exception))
+
+    def test_unconditional_non_last_not_matched_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_not_matched=[
+                    WhenNotMatched(insert='*'),
+                    WhenNotMatched(insert='*', condition='s.age > 10'),
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('when_not_matched', str(ctx.exception))
+        self.assertIn('unreachable', str(ctx.exception))
+
     def test_non_de_table_rejected(self):
         target = self._create_table(options={'row-tracking.enabled': 'true'})
         with self.assertRaises(ValueError) as ctx:
@@ -1315,6 +1349,305 @@ class RayDataEvolutionMergeIntoTest(unittest.TestCase):
         self.assertEqual(out['name'], ['keep'])
         self.assertEqual(out['age'], [99])
 
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_matched_clause_fall_through(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2, 3], type=pa.int32()),
+                    'name': ['a', 'b', 'c'],
+                    'age': pa.array([10, 20, 30], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3], type=pa.int32()),
+                'name': ['a2', 'b2', 'c2'],
+                'age': pa.array([99, 88, 77], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update='*', condition='s.age > 80'),
+                WhenMatched(update='*'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
+        self.assertEqual(out['age'], [99, 88, 77])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_not_matched_clause_fall_through(self):
+        target = self._create_table()
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3], type=pa.int32()),
+                'name': ['a', 'b', 'c'],
+                'age': pa.array([25, 15, 5], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_not_matched=[
+                WhenNotMatched(insert='*', condition='s.age >= 20'),
+                WhenNotMatched(insert='*'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_matched_null_falls_through(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2, 3], type=pa.int32()),
+                    'name': ['a', 'b', 'c'],
+                    'age': pa.array([10, 20, 30], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3], type=pa.int32()),
+                'name': ['a2', 'b2', 'c2'],
+                'age': pa.array([None, 50, 60], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update='*', condition='s.age > 40'),
+                WhenMatched(update='*'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
+        self.assertEqual(out['age'], [None, 50, 60])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_not_matched_null_falls_through(self):
+        target = self._create_table()
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a', 'b'],
+                'age': pa.array([None, 25], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_not_matched=[
+                WhenNotMatched(insert='*', condition='s.age > 20'),
+                WhenNotMatched(insert='*'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['age'], [None, 25])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_clause_no_match_skipped(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2], type=pa.int32()),
+                    'name': ['a', 'b'],
+                    'age': pa.array([10, 20], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a2', 'b2'],
+                'age': pa.array([5, 5], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update='*', condition='s.age > 50'),
+                WhenMatched(update='*', condition='s.age > 30'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['a', 'b'])
+        self.assertEqual(out['age'], [10, 20])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_clause_first_wins(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['first'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update={'name': 's.name'},
+                            condition='s.age > 50'),
+                WhenMatched(update={'age': 's.age'},
+                            condition='s.age > 10'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['first'])
+        self.assertEqual(out['age'], [10])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_clause_duplicate_source_one_actionable(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['a'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 1], type=pa.int32()),
+                'name': ['x', 'y'],
+                'age': pa.array([99, 5], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update='*', condition='s.age > 50'),
+                WhenMatched(update='*', condition='s.age > 80'),
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['x'])
+        self.assertEqual(out['age'], [99])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_multi_clause_duplicate_both_actionable_raises(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['a'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 1], type=pa.int32()),
+                'name': ['x', 'y'],
+                'age': pa.array([99, 50], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        with self.assertRaises(Exception) as ctx:
+            merge_into(
+                target=target,
+                source=source,
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[
+                    WhenMatched(update='*', condition='s.age > 80'),
+                    WhenMatched(update='*', condition='s.age > 30'),
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('multiple source rows', str(ctx.exception))
+
 
 class TargetProjectionTest(unittest.TestCase):
 

Reply via email to