This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new e4d0573aed [python][ray] Ray merge into support condition (#8076)
e4d0573aed is described below

commit e4d0573aed02e341bb8fc6411a5280d7ed4db2b5
Author: XiaoHongbo <[email protected]>
AuthorDate: Wed Jun 3 19:22:19 2026 +0800

    [python][ray] Ray merge into support condition (#8076)
---
 .github/workflows/paimon-python-checks.yml         |   2 +-
 docs/docs/pypaimon/ray-data.md                     |  28 +-
 paimon-python/dev/requirements-dev.txt             |   2 +
 .../pypaimon/ray/data_evolution_merge_into.py      |  61 +++-
 .../pypaimon/ray/data_evolution_merge_join.py      |  32 ++
 .../pypaimon/ray/data_evolution_merge_transform.py |   1 +
 paimon-python/pypaimon/ray/merge_condition.py      | 104 ++++++
 .../tests/ray_data_evolution_merge_into_test.py    | 385 ++++++++++++++++++++-
 8 files changed, 598 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/paimon-python-checks.yml 
b/.github/workflows/paimon-python-checks.yml
index 5b918f5829..f27219e6fa 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -133,7 +133,7 @@ jobs:
           else
             python -m pip install --upgrade pip
             pip install torch --index-url https://download.pytorch.org/whl/cpu
-            python -m pip install pyroaring readerwriterlock==1.0.9 
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 
py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0
+            python -m pip install pyroaring readerwriterlock==1.0.9 
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 
py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 
'datafusion>=52'
             python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION 
}}' -i https://pypi.org/simple/
             if python -c "import sys; sys.exit(0 if sys.version_info >= (3, 
11) else 1)"; then
               python -m pip install vortex-data==0.70.0
diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index 658a1098ae..a3c59c9f68 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -357,12 +357,29 @@ metrics = merge_into(
 print(metrics)   # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
 ```
 
+Conditional clauses filter which matched/unmatched rows are acted on:
+
+```python
+merge_into(
+    target="db.table",
+    source=source_ds,
+    catalog_options=catalog_options,
+    on=["id"],
+    when_matched=[WhenMatched(update="*", condition="s.age > t.age")],
+    when_not_matched=[WhenNotMatched(insert="*", condition="s.age > 18")],
+)
+```
+
+Conditions use SQL-style expressions with `s.` (source) and `t.` (target)
+column prefixes. `WhenNotMatched` conditions may only reference source
+columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.
+
 - `update` / `insert`: only `"*"` is supported in this PR. A future follow-up
   will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are
   analyzable string expressions (`"s.<col>"`, `"t.<col>"`, or literals),
   not Python callables.
-- `condition`: reserved for a future follow-up; passing a non-None value
-  currently raises `NotImplementedError`.
+- `condition`: an optional SQL-style boolean expression. Use `s.<col>` and
+  `t.<col>` to reference source and target columns; rows for which it
+  evaluates to false or NULL are skipped by the clause.
 
 **Parameters:**
 - `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
@@ -375,10 +392,9 @@ print(metrics)   # {"num_matched": 3, "num_inserted": 2, 
"num_unchanged": 0}
   tasks (update transform, group write, insert transform).
 - `concurrency`: scheduling for the insert sink.
 
-**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. In this PR
-every matched row is updated, so `num_matched` always equals `num_updated`
-and `num_unchanged` is always `0`; conditional clauses (added later) can
-make `num_unchanged > 0`.
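+**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. `num_matched`
+counts the matched rows actually updated (after condition filtering) and
+`num_inserted` the rows actually inserted. Matched rows rejected by a
+condition are left untouched but are not counted anywhere; `num_unchanged`
+is always `0` in the current implementation.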
 
 **Notes:**
 - Partition key columns cannot be updated by matched clauses. If the target
diff --git a/paimon-python/dev/requirements-dev.txt 
b/paimon-python/dev/requirements-dev.txt
index 9ef88817f7..c83a2e44b8 100644
--- a/paimon-python/dev/requirements-dev.txt
+++ b/paimon-python/dev/requirements-dev.txt
@@ -28,5 +28,7 @@ requests
 parameterized
 # Vortex 0.71.0 regresses native predicate pushdown on single-row files.
 vortex-data==0.70.0; python_version >= "3.11"
+# merge_into condition expressions (optional, for condition tests)
+datafusion>=52; python_version >= "3.10"
 # Lumina vector search (optional, for lumina index tests)
 lumina-data>=0.1.0
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py 
b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
index b90abdc745..5be68f301e 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -96,12 +96,6 @@ def _prepare(target, source, catalog_options, when_matched, 
when_not_matched, on
             "WhenNotMatched clause; multi-clause fall-through will be added "
             "in a follow-up PR."
         )
-    for clause in list(when_matched) + list(when_not_matched):
-        if clause.condition is not None:
-            raise NotImplementedError(
-                "merge_into does not yet support condition expressions; "
-                "this will be added in a follow-up PR."
-            )
     target_on_cols, source_on_cols = _normalize_on(on)
 
     from pypaimon.catalog.catalog_factory import CatalogFactory
@@ -136,14 +130,41 @@ def _prepare(target, source, catalog_options, 
when_matched, when_not_matched, on
             spec=_normalize_set_spec(
                 c.update, settable_field_names, on_map,
             ),
+            condition=c.condition,
         )
         for c in when_matched
     ]
+    has_condition = any(
+        c.condition is not None
+        for c in list(when_matched) + list(when_not_matched)
+    )
+    if has_condition:
+        from pypaimon.ray.merge_condition import (
+            _require_datafusion, extract_target_columns,
+        )
+        _require_datafusion()
+        for c in when_not_matched:
+            if c.condition is not None:
+                t_refs = extract_target_columns(c.condition)
+                if t_refs:
+                    raise ValueError(
+                        f"WhenNotMatched condition must not reference "
+                        f"target columns (t.*), but found: {sorted(t_refs)}"
+                    )
+        for c in list(when_matched) + list(when_not_matched):
+            if c.condition is not None:
+                blob_refs = extract_target_columns(c.condition) & blob_cols
+                if blob_refs:
+                    raise ValueError(
+                        f"condition must not reference blob columns, "
+                        f"but found: {sorted(blob_refs)}"
+                    )
     not_matched_specs = [
         _NormalizedClause(
             spec=_normalize_set_spec(
                 c.insert, settable_field_names, on_map,
             ),
+            condition=c.condition,
         )
         for c in when_not_matched
     ]
@@ -154,6 +175,25 @@ def _prepare(target, source, catalog_options, 
when_matched, when_not_matched, on
         source_ds, settable_field_names, on_map,
     )
 
+    if has_condition:
+        from pypaimon.ray.merge_condition import extract_columns
+        source_names = set(_source_schema_or_raise(source_ds).names)
+        target_names = set(full_target_field_names)
+        for c in list(when_matched) + list(when_not_matched):
+            if c.condition is not None:
+                for ref in extract_columns(c.condition):
+                    prefix, col = ref.split(".", 1)
+                    if prefix == "s" and col not in source_names:
+                        raise ValueError(
+                            f"condition references unknown source "
+                            f"column '{col}'"
+                        )
+                    if prefix == "t" and col not in target_names:
+                        raise ValueError(
+                            f"condition references unknown target "
+                            f"column '{col}'"
+                        )
+
     from pypaimon.schema.data_types import PyarrowFieldParser
     full_pa_schema = PyarrowFieldParser.from_paimon_schema(
         table.table_schema.fields
@@ -272,8 +312,7 @@ def _execute_and_commit(
         tc.commit(all_msgs)
         tc.close()
 
-    # MVP has no condition, so every matched row is updated; num_unchanged
-    # is always 0. Kept in the dict for API stability when condition lands.
+    # num_matched = rows that passed the condition and were updated
     return {
         "num_matched": num_updated,
         "num_inserted": num_inserted,
@@ -375,6 +414,12 @@ def _resolve_target_projection(
     needed = set(_needed_target_cols(
         clauses, target_on, update_cols, target_field_names,
     ))
+    if any(c.condition is not None for c in clauses):
+        from pypaimon.ray.merge_condition import extract_target_columns
+        target_set = set(target_field_names)
+        for clause in clauses:
+            if clause.condition is not None:
+                needed |= extract_target_columns(clause.condition) & target_set
     return [c for c in target_field_names if c in needed]
 
 
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py 
b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
index f1b814a861..f01f9b59aa 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -94,12 +94,31 @@ def build_matched_update_ds(
         f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
     )
     spec = clauses[0].spec
+    condition = clauses[0].condition
     captured_update_cols = list(update_cols)
     captured_row_id_name = row_id_name
     captured_on_pairs = list(zip(source_on, target_on))
     captured_schema = update_schema
 
+    captured_apply = None
+    captured_rewritten = None
+    if condition is not None:
+        from pypaimon.ray.merge_condition import (
+            apply_condition, remap_source_on_keys, rewrite_condition,
+        )
+        on_map = dict(zip(source_on, target_on))
+        captured_rewritten = remap_source_on_keys(
+            rewrite_condition(condition), on_map,
+        )
+        captured_apply = apply_condition
+
     def _transform(batch: pa.Table) -> pa.Table:
+        if captured_apply is not None:
+            batch = captured_apply(
+                batch, captured_rewritten, captured_schema,
+            )
+            if batch.num_rows == 0:
+                return batch
         return vectorized_matched_transform(
             batch, spec, captured_on_pairs,
             captured_update_cols, captured_row_id_name,
@@ -311,8 +330,21 @@ def build_not_matched_insert_ds(
         f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}"
     )
     spec = clauses[0].spec
+    condition = clauses[0].condition
+    captured_apply = None
+    captured_rewritten = None
+    if condition is not None:
+        from pypaimon.ray.merge_condition import apply_condition, 
rewrite_condition
+        captured_rewritten = rewrite_condition(condition)
+        captured_apply = apply_condition
 
     def _transform(batch: pa.Table) -> pa.Table:
+        if captured_apply is not None:
+            batch = captured_apply(
+                batch, captured_rewritten, out_schema,
+            )
+            if batch.num_rows == 0:
+                return _coerce_large_string_types(batch)
         return _coerce_large_string_types(
             vectorized_insert_transform(
                 batch, spec, captured_field_names, out_schema
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py 
b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
index 0fc2d22f77..ed786467f1 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
@@ -40,6 +40,7 @@ class WhenNotMatched:
 @dataclass
 class _NormalizedClause:
+    """Validated merge clause: normalized SET/INSERT spec plus condition."""
     spec: Dict[str, Any]
+    # Optional SQL-style boolean expression over ``s.<col>``/``t.<col>``
+    # references; None means the clause applies to every row.
+    condition: Optional[str] = None
 
 
 def vectorized_matched_transform(
diff --git a/paimon-python/pypaimon/ray/merge_condition.py 
b/paimon-python/pypaimon/ray/merge_condition.py
new file mode 100644
index 0000000000..5497406c5c
--- /dev/null
+++ b/paimon-python/pypaimon/ray/merge_condition.py
@@ -0,0 +1,104 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import re
+from typing import Mapping, Set
+
+import pyarrow as pa
+
+
+# Matches a ``s.<col>`` (source) or ``t.<col>`` (target) column reference:
+# group 1 is the prefix letter, group 2 the column name (word chars only).
+_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b')
+
+
+def _require_datafusion():
+    try:
+        import datafusion
+        return datafusion
+    except ImportError:
+        raise ImportError(
+            "merge_into condition expressions require the 'datafusion' "
+            "package. Install it with: pip install pypaimon[sql]"
+        )
+
+
+_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")
+
+
+def _strip_string_literals(condition: str) -> str:
+    return _STRING_LITERAL.sub('', condition)
+
+
+def rewrite_condition(condition: str) -> str:
+    parts, last = [], 0
+    for m in _STRING_LITERAL.finditer(condition):
+        parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', 
condition[last:m.start()]))
+        parts.append(m.group())
+        last = m.end()
+    parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:]))
+    return ''.join(parts)
+
+
+def remap_source_on_keys(
+    rewritten: str, on_map: Mapping[str, str],
+) -> str:
+    for s_col, t_col in on_map.items():
+        old, new = f'"s.{s_col}"', f'"t.{t_col}"'
+        parts, last = [], 0
+        for m in _STRING_LITERAL.finditer(rewritten):
+            parts.append(rewritten[last:m.start()].replace(old, new))
+            parts.append(m.group())
+            last = m.end()
+        parts.append(rewritten[last:].replace(old, new))
+        rewritten = ''.join(parts)
+    return rewritten
+
+
+def filter_batch(
+    batch: pa.Table, condition: str, _pre_rewritten: bool = False,
+) -> pa.Table:
+    if batch.num_rows == 0:
+        return batch
+    datafusion = _require_datafusion()
+    rewritten = condition if _pre_rewritten else rewrite_condition(condition)
+    ctx = datafusion.SessionContext()
+    ctx.register_record_batches("_batch", [batch.to_batches()])
+    result = ctx.sql(
+        f'SELECT * FROM _batch WHERE {rewritten}'
+    )
+    return result.to_arrow_table()
+
+
+def apply_condition(
+    batch: pa.Table, rewritten: str, empty_schema: pa.Schema,
+) -> pa.Table:
+    batch = filter_batch(batch, rewritten, _pre_rewritten=True)
+    if batch.num_rows == 0:
+        return empty_schema.empty_table()
+    return batch
+
+
+def extract_columns(condition: str) -> Set[str]:
+    stripped = _strip_string_literals(condition)
+    return {f"{m.group(1)}.{m.group(2)}"
+            for m in _COL_REF_PATTERN.finditer(stripped)}
+
+
+def extract_target_columns(condition: str) -> Set[str]:
+    stripped = _strip_string_literals(condition)
+    return {m.group(2) for m in _COL_REF_PATTERN.finditer(stripped)
+            if m.group(1) == "t"}
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py 
b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
index 727185a2e4..47981088f2 100644
--- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -28,6 +28,15 @@ import ray
 from pypaimon import CatalogFactory, Schema
 from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into
 
+# Condition tests need the optional ``datafusion`` package (listed in
+# dev/requirements-dev.txt); they are skipped when it is unavailable.
+try:
+    import datafusion  # noqa: F401
+    _HAS_DATAFUSION = True
+except ImportError:
+    _HAS_DATAFUSION = False
+
+_SKIP_CONDITION = not _HAS_DATAFUSION
+_SKIP_REASON = "datafusion not installed"
+
 _TEST_NUM_PARTITIONS = 2
 
 
@@ -153,6 +162,56 @@ class RayDataEvolutionMergeIntoTest(unittest.TestCase):
             )
         self.assertIn("'id'", str(ctx.exception))
 
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_not_matched_condition_rejects_target_refs(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_not_matched=[
+                    WhenNotMatched(insert='*', condition='t.age > 10')
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('t.', str(ctx.exception))
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_condition_unknown_source_col_rejected(self):
+        target = self._create_table()
+        self._write(target, self._source())
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[
+                    WhenMatched(update='*', condition='s.nonexistent > 0')
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('nonexistent', str(ctx.exception))
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_condition_unknown_target_col_rejected(self):
+        target = self._create_table()
+        self._write(target, self._source())
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[
+                    WhenMatched(update='*', condition='s.age > t.nonexistent')
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('nonexistent', str(ctx.exception))
+
     def test_matched_update_star(self):
         target = self._create_table()
         self._write(
@@ -519,12 +578,249 @@ class RayDataEvolutionMergeIntoTest(unittest.TestCase):
         self.assertEqual(out['id'], [1, 2])
         self.assertEqual(out['pt'], ['a', 'b'])
 
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_matched_update_with_condition(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2, 3], type=pa.int32()),
+                    'name': ['a', 'b', 'c'],
+                    'age': pa.array([10, 20, 30], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3], type=pa.int32()),
+                'name': ['a2', 'b2', 'c2'],
+                'age': pa.array([15, 25, 45], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update='*', condition='s.age > t.age + 
10')],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a', 'b', 'c2'])
+        self.assertEqual(out['age'], [10, 20, 45])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_matched_condition_with_source_on_key(self):
+        # s.id is a join key; on the matched path it is remapped to the
+        # target key column (remap_source_on_keys) before evaluation. Only
+        # ids 2 and 3 satisfy s.id >= 2 and are updated.
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2, 3], type=pa.int32()),
+                    'name': ['a', 'b', 'c'],
+                    'age': pa.array([10, 20, 30], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3], type=pa.int32()),
+                'name': ['a2', 'b2', 'c2'],
+                'age': pa.array([15, 25, 35], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update='*', condition='s.id >= 2')],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a', 'b2', 'c2'])
+        self.assertEqual(out['age'], [10, 25, 35])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_not_matched_insert_with_condition(self):
+        # Ids 2-4 are all unmatched; only those with s.age >= 10 (ids 2 and
+        # 3) are inserted, id 4 (age 5) is skipped.
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['a'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([2, 3, 4], type=pa.int32()),
+                'name': ['b', 'c', 'd'],
+                'age': pa.array([15, 25, 5], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_not_matched=[
+                WhenNotMatched(insert='*', condition='s.age >= 10')
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a', 'b', 'c'])
+        self.assertEqual(out['age'], [10, 15, 25])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_combined_with_conditions(self):
+        # Matched: only id 1 passes s.age > t.age (50 > 10); id 2 (5 > 20)
+        # is left alone. Unmatched: only id 3 passes s.age > 10 (30); id 4
+        # (8) is not inserted.
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2], type=pa.int32()),
+                    'name': ['a', 'b'],
+                    'age': pa.array([10, 20], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2, 3, 4], type=pa.int32()),
+                'name': ['a2', 'b2', 'c', 'd'],
+                'age': pa.array([50, 5, 30, 8], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        metrics = merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
+            when_not_matched=[
+                WhenNotMatched(insert='*', condition='s.age > 10')
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2, 3])
+        self.assertEqual(out['name'], ['a2', 'b', 'c'])
+        self.assertEqual(out['age'], [50, 20, 30])
+        self.assertEqual(metrics['num_matched'], 1)
+        self.assertEqual(metrics['num_inserted'], 1)
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_condition_no_rows_match_is_noop(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2], type=pa.int32()),
+                    'name': ['a', 'b'],
+                    'age': pa.array([10, 20], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a2', 'b2'],
+                'age': pa.array([5, 5], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['name'], ['a', 'b'])
+        self.assertEqual(out['age'], [10, 20])
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_duplicate_source_filtered_by_condition(self):
+        # The source has two rows for id 1; the condition keeps only the
+        # age-20 row (20 > 10), so exactly one update reaches the target.
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['a'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 1], type=pa.int32()),
+                'name': ['x', 'y'],
+                'age': pa.array([5, 20], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[
+                WhenMatched(update='*', condition='s.age > t.age')
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1])
+        self.assertEqual(out['name'], ['y'])
+        self.assertEqual(out['age'], [20])
+
 
 class TargetProjectionTest(unittest.TestCase):
 
-    def _clause(self, spec):
+    def _clause(self, spec, condition=None):
+        """Build a _NormalizedClause; ``condition`` defaults to None."""
         from pypaimon.ray import data_evolution_merge_into as m
-        return m._NormalizedClause(spec=spec)
+        return m._NormalizedClause(spec=spec, condition=condition)
 
     def test_unconditional_set_excludes_target_update_col(self):
         from pypaimon.ray import data_evolution_merge_into as m
@@ -534,6 +830,91 @@ class TargetProjectionTest(unittest.TestCase):
         )
         self.assertEqual(['id'], cols)
 
+    def test_condition_adds_referenced_target_cols(self):
+        from pypaimon.ray import data_evolution_merge_into as m
+        cols = m._resolve_target_projection(
+            [self._clause({'feature': 's.feature'}, condition='s.age > 
t.age')],
+            ['id'], ['feature'], ['id', 'feature', 'age', 'image'],
+        )
+        self.assertIn('age', cols)
+        self.assertIn('id', cols)
+
+
+class MergeConditionUnitTest(unittest.TestCase):
+
+    def test_rewrite_condition(self):
+        from pypaimon.ray.merge_condition import rewrite_condition
+        self.assertEqual(
+            rewrite_condition('s.age > t.age + 10'),
+            '"s.age" > "t.age" + 10',
+        )
+
+    def test_rewrite_condition_preserves_string_literals(self):
+        from pypaimon.ray.merge_condition import rewrite_condition
+        self.assertEqual(
+            rewrite_condition("s.status = 't.active' AND s.age > t.age"),
+            '"s.status" = \'t.active\' AND "s.age" > "t.age"',
+        )
+
+    def test_remap_source_on_keys(self):
+        from pypaimon.ray.merge_condition import (
+            remap_source_on_keys, rewrite_condition,
+        )
+        rewritten = rewrite_condition('s.id > 1 AND s.age > t.age')
+        remapped = remap_source_on_keys(rewritten, {'id': 'id'})
+        self.assertEqual(remapped, '"t.id" > 1 AND "s.age" > "t.age"')
+
+    def test_remap_source_on_keys_renamed(self):
+        from pypaimon.ray.merge_condition import (
+            remap_source_on_keys, rewrite_condition,
+        )
+        rewritten = rewrite_condition('s.uid > 1')
+        remapped = remap_source_on_keys(rewritten, {'uid': 'id'})
+        self.assertEqual(remapped, '"t.id" > 1')
+
+    def test_remap_preserves_string_literals(self):
+        from pypaimon.ray.merge_condition import (
+            remap_source_on_keys, rewrite_condition,
+        )
+        rewritten = rewrite_condition("s.note = '\"s.id\"' AND s.id = 1")
+        remapped = remap_source_on_keys(rewritten, {'id': 'id'})
+        self.assertEqual(
+            remapped,
+            '"s.note" = \'\"s.id\"\' AND "t.id" = 1',
+        )
+
+    def test_extract_target_columns(self):
+        from pypaimon.ray.merge_condition import extract_target_columns
+        self.assertEqual(
+            extract_target_columns('s.name = t.name AND s.age > t.age'),
+            {'name', 'age'},
+        )
+
+    def test_extract_target_columns_ignores_string_literals(self):
+        from pypaimon.ray.merge_condition import extract_target_columns
+        self.assertEqual(
+            extract_target_columns("s.name = 't.fake' AND s.age > t.age"),
+            {'age'},
+        )
+
+    def test_extract_columns(self):
+        from pypaimon.ray.merge_condition import extract_columns
+        self.assertEqual(
+            extract_columns('s.id = t.id AND s.age > t.age'),
+            {'s.id', 't.id', 's.age', 't.age'},
+        )
+
+    @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+    def test_filter_batch(self):
+        from pypaimon.ray.merge_condition import filter_batch
+        batch = pa.table({
+            's.id': pa.array([1, 2, 3], type=pa.int32()),
+            's.age': pa.array([10, 25, 30], type=pa.int32()),
+            't.age': pa.array([20, 20, 20], type=pa.int32()),
+        })
+        result = filter_batch(batch, 's.age > t.age')
+        self.assertEqual(result.column('s.id').to_pylist(), [2, 3])
+
 
 if __name__ == '__main__':
     unittest.main()

Reply via email to