This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new d78babb99f [python][ray] Support partial SET and INSERT in merge_into (#8085)
d78babb99f is described below

commit d78babb99fe7d25a72b002624ee0570a5a2c47c9
Author: XiaoHongbo <[email protected]>
AuthorDate: Thu Jun 4 13:51:13 2026 +0800

    [python][ray] Support partial SET and INSERT in merge_into (#8085)
---
 docs/docs/pypaimon/ray-data.md                     |  14 +-
 paimon-python/pypaimon/ray/__init__.py             |   8 +
 .../pypaimon/ray/data_evolution_merge_into.py      | 108 +++--
 .../pypaimon/ray/data_evolution_merge_transform.py |  40 +-
 .../tests/ray_data_evolution_merge_into_test.py    | 471 ++++++++++++++++++++-
 5 files changed, 604 insertions(+), 37 deletions(-)

diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index a3c59c9f68..d160b4302f 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -374,10 +374,16 @@ Conditions use SQL-style expressions with `s.` (source) and `t.` (target)
 column prefixes. `WhenNotMatched` conditions may only reference source
 columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.
 
-- `update` / `insert`: only `"*"` is supported in this PR. A future follow-up
-  will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are
-  analyzable string expressions (`"s.<col>"`, `"t.<col>"`, or literals),
-  not Python callables.
+- `update` / `insert`: `"*"` updates/inserts all non-blob columns from source.
+  A mapping selects specific columns:
+  ```python
+  from pypaimon.ray import source_col, target_col, lit
+
+  WhenMatched(update={"age": source_col("age"), "name": target_col("name")})
+  WhenNotMatched(insert={"id": source_col("id"), "status": lit("new")})
+  ```
+  `"s.<col>"` / `"t.<col>"` shorthands also work (`t.*` only in update).
+  Use `lit()` for literals starting with `s.` or `t.`.
 - `condition`: an optional SQL-style boolean expression. Use `s.<col>` and
   `t.<col>` to reference source and target columns.
 
diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py
index 9161f3cbb3..4280187956 100644
--- a/paimon-python/pypaimon/ray/__init__.py
+++ b/paimon-python/pypaimon/ray/__init__.py
@@ -21,6 +21,11 @@ from pypaimon.ray.data_evolution_merge_into import (
     WhenNotMatched,
     merge_into,
 )
+from pypaimon.ray.data_evolution_merge_transform import (
+    source_col,
+    target_col,
+    lit,
+)
 
 __all__ = [
     "read_paimon",
@@ -28,4 +33,7 @@ __all__ = [
     "merge_into",
     "WhenMatched",
     "WhenNotMatched",
+    "source_col",
+    "target_col",
+    "lit",
 ]
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
index 379655ec77..fa824b44a2 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -30,8 +30,11 @@ from pypaimon.ray.data_evolution_merge_join import (
     distributed_write_collect_msgs,
 )
 from pypaimon.ray.data_evolution_merge_transform import (
+    LiteralValue,
     OnSpec,
     SetSpec,
+    SourceColumnRef,
+    TargetColumnRef,
     WhenMatched,
     WhenNotMatched,
     _NormalizedClause,
@@ -159,15 +162,18 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
                         f"condition must not reference blob columns, "
                         f"but found: {sorted(blob_refs)}"
                     )
-    not_matched_specs = [
-        _NormalizedClause(
-            spec=_normalize_set_spec(
-                c.insert, settable_field_names, on_map,
-            ),
-            condition=c.condition,
+    not_matched_specs = []
+    for c in when_not_matched:
+        spec = _normalize_set_spec(
+            c.insert, settable_field_names, on_map,
+            allow_target_refs=False,
+        )
+        for tk, sk in on_map.items():
+            if tk in settable_field_names and tk not in spec:
+                spec[tk] = SourceColumnRef(sk)
+        not_matched_specs.append(
+            _NormalizedClause(spec=spec, condition=c.condition)
         )
-        for c in when_not_matched
-    ]
 
     source_snapshot_id = None
     if isinstance(source, str):
@@ -184,7 +190,7 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
     )
     _validate_source_on_cols(source_ds, source_on_cols)
     _validate_source_has_target_cols(
-        source_ds, settable_field_names, on_map,
+        source_ds, matched_specs + not_matched_specs,
     )
 
     if has_condition:
@@ -410,8 +416,8 @@ def _needed_target_cols(
     set_by_all = set(update_cols)
     for clause in clauses:
         for value in clause.spec.values():
-            if isinstance(value, str) and value.startswith("t."):
-                needed.add(value[2:])
+            if isinstance(value, TargetColumnRef):
+                needed.add(value.column)
         set_by_all &= set(clause.spec.keys())
     needed |= set(update_cols) - set_by_all
     return [c for c in all_target_cols if c in needed]
@@ -439,15 +445,66 @@ def _normalize_set_spec(
     spec: SetSpec,
     target_field_names: Sequence[str],
     on_map: Optional[Mapping[str, str]] = None,
+    allow_target_refs: bool = True,
 ) -> Dict[str, Any]:
     on_map = on_map or {}
-    if spec != "*":
-        raise NotImplementedError(
-            "merge_into currently only supports '*' for update/insert; "
-            "partial SET will be added in a follow-up PR."
+    if spec == "*":
+        return {
+            col: SourceColumnRef(on_map.get(col, col))
+            for col in target_field_names
+        }
+    if not isinstance(spec, Mapping):
+        raise TypeError(
+            f"SET spec must be '*' or a mapping, got {type(spec).__name__}"
         )
-    # A renamed ON key resolves via the source's ON column, not its own name.
-    return {col: f"s.{on_map.get(col, col)}" for col in target_field_names}
+    if not spec:
+        raise ValueError("SET spec must not be empty")
+    target_set = set(target_field_names)
+    for key in spec:
+        if key not in target_set:
+            raise ValueError(
+                f"SET spec references unknown target column '{key}'"
+            )
+    result: Dict[str, Any] = {}
+    for key, val in spec.items():
+        if callable(val) and not isinstance(val, type):
+            raise TypeError(
+                "SET values must be source_col(), target_col(), "
+                "lit(), or literals, not callables"
+            )
+        if isinstance(val, SourceColumnRef):
+            result[key] = val
+        elif isinstance(val, TargetColumnRef):
+            if not allow_target_refs:
+                raise ValueError(
+                    "INSERT spec must not reference target columns "
+                    f"(t.*), but found: 't.{val.column}'"
+                )
+            if val.column not in target_set:
+                raise ValueError(
+                    f"SET spec references unknown target column "
+                    f"'{val.column}'"
+                )
+            result[key] = val
+        elif isinstance(val, LiteralValue):
+            result[key] = val
+        elif isinstance(val, str) and val.startswith("s."):
+            result[key] = SourceColumnRef(val[2:])
+        elif isinstance(val, str) and val.startswith("t."):
+            if not allow_target_refs:
+                raise ValueError(
+                    "INSERT spec must not reference target columns "
+                    f"(t.*), but found: '{val}'"
+                )
+            ref = val[2:]
+            if ref not in target_set:
+                raise ValueError(
+                    f"SET spec references unknown target column '{ref}'"
+                )
+            result[key] = TargetColumnRef(ref)
+        else:
+            result[key] = LiteralValue(val)
+    return result
 
 
 def _normalize_source(
@@ -502,17 +559,16 @@ def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:
 
 def _validate_source_has_target_cols(
     source_ds,
-    target_field_names: Sequence[str],
-    on_map: Mapping[str, str],
+    specs: List[_NormalizedClause],
 ) -> None:
-    """For update='*'/insert='*', source must carry every (non-blob) target
-    column; otherwise the SET spec resolves to null and silently overwrites."""
     names = set(_source_schema_or_raise(source_ds).names)
-    expected = {on_map.get(c, c) for c in target_field_names}
-    missing = sorted(expected - names)
+    needed = set()
+    for clause in specs:
+        for val in clause.spec.values():
+            if isinstance(val, SourceColumnRef):
+                needed.add(val.column)
+    missing = sorted(needed - names)
     if missing:
         raise ValueError(
-            f"source is missing target columns {missing}; "
-            f"update='*'/insert='*' requires the source to carry every "
-            f"(non-blob) target column."
+            f"source is missing columns {missing} referenced by SET spec"
         )
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
index ed786467f1..003977f3e7 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
@@ -25,6 +25,33 @@ SetSpec = Union[str, Mapping[str, Any]]
 OnSpec = Union[Sequence[str], Mapping[str, str]]
 
 
+@dataclass(frozen=True)
+class SourceColumnRef:
+    column: str
+
+
+@dataclass(frozen=True)
+class TargetColumnRef:
+    column: str
+
+
+@dataclass(frozen=True)
+class LiteralValue:
+    value: Any
+
+
+def source_col(name: str) -> SourceColumnRef:
+    return SourceColumnRef(name)
+
+
+def target_col(name: str) -> TargetColumnRef:
+    return TargetColumnRef(name)
+
+
+def lit(value: Any) -> LiteralValue:
+    return LiteralValue(value)
+
+
 @dataclass
 class WhenMatched:
     update: SetSpec
@@ -105,18 +132,19 @@ def _resolve_spec_array(
     on_pairs: Sequence[Tuple[str, str]],
     out_type: pa.DataType,
 ):
-    if isinstance(val, str) and val.startswith("s."):
-        ref = val[2:]
+    if isinstance(val, LiteralValue):
+        return pa.array([val.value] * batch.num_rows, type=out_type)
+    if isinstance(val, SourceColumnRef):
+        ref = val.column
         if f"s.{ref}" in available:
             return batch.column(f"s.{ref}")
         for sk, tk in on_pairs:
             if sk == ref and f"t.{tk}" in available:
                 return batch.column(f"t.{tk}")
         return pa.nulls(batch.num_rows, type=out_type)
-    if isinstance(val, str) and val.startswith("t."):
-        ref = val[2:]
-        col_name = f"t.{ref}"
+    if isinstance(val, TargetColumnRef):
+        col_name = f"t.{val.column}"
         return batch.column(col_name) if col_name in available else pa.nulls(
             batch.num_rows, type=out_type
         )
-    return pa.array([val] * batch.num_rows, type=out_type)
+    raise TypeError(f"unexpected spec value type: {type(val).__name__}")
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
index 7be8668320..ca06d43e53 100644
--- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -27,7 +27,10 @@ import pyarrow as pa
 import ray
 
 from pypaimon import CatalogFactory, Schema
-from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into
+from pypaimon.ray import (
+    WhenMatched, WhenNotMatched, merge_into,
+    source_col, target_col, lit,
+)
 
 try:
     import datafusion  # noqa: F401
@@ -846,6 +849,472 @@ class RayDataEvolutionMergeIntoTest(unittest.TestCase):
         self.assertEqual(out['name'], ['y'])
         self.assertEqual(out['age'], [20])
 
+    def test_matched_partial_update(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1, 2], type=pa.int32()),
+                    'name': ['a', 'b'],
+                    'age': pa.array([10, 20], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a2', 'b2'],
+                'age': pa.array([99, 88], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={'age': 's.age'})],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['name'], ['a', 'b'])
+        self.assertEqual(out['age'], [99, 88])
+
+    def test_insert_partial_mapping(self):
+        target = self._create_table()
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a', 'b'],
+                'age': pa.array([10, 20], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_not_matched=[
+                WhenNotMatched(insert={'id': 's.id', 'name': 's.name'})
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['name'], ['a', 'b'])
+        self.assertEqual(out['age'], [None, None])
+
+    def test_update_with_literal(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['ignored'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={'name': 'updated'})],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['updated'])
+        self.assertEqual(out['age'], [10])
+
+    def test_invalid_target_column_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[WhenMatched(update={'nonexistent': 's.id'})],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('nonexistent', str(ctx.exception))
+
+    def test_invalid_target_ref_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[WhenMatched(update={'name': 't.nme'})],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('nme', str(ctx.exception))
+
+    def test_empty_mapping_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError):
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[WhenMatched(update={})],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+
+    def test_insert_target_ref_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_not_matched=[
+                    WhenNotMatched(insert={'name': 't.name'})
+                ],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('t.', str(ctx.exception))
+
+    def test_matched_update_with_target_ref(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['ignored'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={'age': 's.age', 'name': 't.name'})],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['old'])
+        self.assertEqual(out['age'], [99])
+
+    def test_callable_value_rejected(self):
+        target = self._create_table()
+        with self.assertRaises(TypeError):
+            merge_into(
+                target=target,
+                source=self._source(),
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[WhenMatched(update={'name': lambda r: r})],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+
+    def test_source_missing_referenced_col(self):
+        target = self._create_table()
+        source = pa.Table.from_pydict(
+            {'id': pa.array([1], type=pa.int32())},
+            schema=pa.schema([('id', pa.int32())]),
+        )
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=source,
+                catalog_options=self.catalog_options,
+                on=['id'],
+                when_matched=[WhenMatched(update={'name': 's.name'})],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('name', str(ctx.exception))
+
+    def test_partial_insert_auto_fills_on_key(self):
+        target = self._create_table()
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1, 2], type=pa.int32()),
+                'name': ['a', 'b'],
+                'age': pa.array([10, 20], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_not_matched=[
+                WhenNotMatched(insert={'name': 's.name'})
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['name'], ['a', 'b'])
+
+    def test_partial_insert_renamed_on_key_auto_filled(self):
+        target = self._create_table()
+
+        source_schema = pa.schema([
+            ('uid', pa.int32()),
+            ('name', pa.string()),
+            ('age', pa.int32()),
+        ])
+        source = pa.Table.from_pydict(
+            {
+                'uid': pa.array([1, 2], type=pa.int32()),
+                'name': ['a', 'b'],
+                'age': pa.array([10, 20], type=pa.int32()),
+            },
+            schema=source_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on={'id': 'uid'},
+            when_not_matched=[
+                WhenNotMatched(insert={'name': 's.name'})
+            ],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['id'], [1, 2])
+        self.assertEqual(out['name'], ['a', 'b'])
+
+    def test_explicit_source_ref_not_remapped_by_on_key(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source_schema = pa.schema([
+            ('uid', pa.int32()),
+            ('id', pa.int32()),
+            ('name', pa.string()),
+            ('age', pa.int32()),
+        ])
+        source = pa.Table.from_pydict(
+            {
+                'uid': pa.array([1], type=pa.int32()),
+                'id': pa.array([42], type=pa.int32()),
+                'name': ['new'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=source_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on={'id': 'uid'},
+            when_matched=[WhenMatched(update={
+                'age': source_col('id'),
+            })],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['age'], [42])
+        self.assertEqual(out['name'], ['old'])
+
+    def test_renamed_on_key_missing_source_col_rejected(self):
+        target = self._create_table()
+        source_schema = pa.schema([
+            ('uid', pa.int32()),
+            ('name', pa.string()),
+            ('age', pa.int32()),
+        ])
+        source = pa.Table.from_pydict(
+            {
+                'uid': pa.array([1], type=pa.int32()),
+                'name': ['a'],
+                'age': pa.array([10], type=pa.int32()),
+            },
+            schema=source_schema,
+        )
+        with self.assertRaises(ValueError) as ctx:
+            merge_into(
+                target=target,
+                source=source,
+                catalog_options=self.catalog_options,
+                on={'id': 'uid'},
+                when_matched=[WhenMatched(update={
+                    'id': source_col('id'),
+                })],
+                num_partitions=_TEST_NUM_PARTITIONS,
+            )
+        self.assertIn('id', str(ctx.exception))
+
+    def test_lit_prevents_column_ref_interpretation(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['ignored'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={
+                'name': lit('s.active'),
+            })],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['s.active'])
+        self.assertEqual(out['age'], [10])
+
+    def test_source_col_helper(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['old'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['new'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={
+                'age': source_col('age'),
+            })],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['old'])
+        self.assertEqual(out['age'], [99])
+
+    def test_target_col_helper(self):
+        target = self._create_table()
+        self._write(
+            target,
+            pa.Table.from_pydict(
+                {
+                    'id': pa.array([1], type=pa.int32()),
+                    'name': ['keep'],
+                    'age': pa.array([10], type=pa.int32()),
+                },
+                schema=self.pa_schema,
+            ),
+        )
+
+        source = pa.Table.from_pydict(
+            {
+                'id': pa.array([1], type=pa.int32()),
+                'name': ['ignored'],
+                'age': pa.array([99], type=pa.int32()),
+            },
+            schema=self.pa_schema,
+        )
+
+        merge_into(
+            target=target,
+            source=source,
+            catalog_options=self.catalog_options,
+            on=['id'],
+            when_matched=[WhenMatched(update={
+                'age': source_col('age'),
+                'name': target_col('name'),
+            })],
+            num_partitions=_TEST_NUM_PARTITIONS,
+        )
+
+        out = self._read_sorted(target)
+        self.assertEqual(out['name'], ['keep'])
+        self.assertEqual(out['age'], [99])
+
 
 class TargetProjectionTest(unittest.TestCase):
 

Reply via email to