This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a100c0bfe5e7 [SPARK-56651][CONNECT][SDP] Add Python APIs for Auto CDC SCD Type 1
a100c0bfe5e7 is described below

commit a100c0bfe5e749442620901ee20d1dc1b9b47aba
Author: AnishMahto <[email protected]>
AuthorDate: Tue May 26 16:45:18 2026 +0800

    [SPARK-56651][CONNECT][SDP] Add Python APIs for Auto CDC SCD Type 1
    
    ## Takeover of https://github.com/apache/spark/pull/56045. The PR description below is copied from that PR.
    
    ### What changes were proposed in this pull request?
    Adds `create_auto_cdc_flow` to the SDP Python API. For now, it only supports SCD Type 1. Parameters:
    - name: the name of the flow
    - target: the target table
    - source: the source dataset with the change events
    - keys: the column(s) that uniquely identify a row
    - sequence_by: a sequence id to establish time order
    - apply_as_deletes: a boolean expression indicating whether an event 
represents a delete
    - ~~apply_as_truncates: a boolean expression indicating whether an event 
represents a truncation~~
    - column_list: a list of columns to include in the target table
    - except_column_list: a list of columns to exclude from the target table
    - stored_as_scd_type: the SCD type; must be 1
    - ~~ignore_null_updates_column_list: a list of columns for which to ignore 
null values~~
    - ~~ignore_null_updates_except_column_list: a list of columns for which not 
to ignore null values~~
    - source_code_location: the location in the Python source code that defines this flow (captured automatically from the caller; not a user-supplied argument)
    
    This PR introduces the PySpark API to register an Auto CDC flow within a Spark Declarative Pipeline (SDP), and to send the registration requests to the Spark driver via Spark Connect protos.
    
    This PR does not handle these Spark Connect protos on the server side. Until that is implemented, the pipelines handler in the Spark driver rejects them with an unsupported/unrecognized-operation error.
    
    ### Why are the changes needed?
    See the SPIP at 
https://docs.google.com/document/d/1Hp5BGEYJRHbk6J7XUph3bAPZKRQXKOuV1PEaqZMMRoQ/
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it introduces a new method in the SDP Python API.
    
    ### How was this patch tested?
    Unit tests were added, using a local graph registry.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Generated-by: Claude Sonnet 4.6
    
    Closes #56069 from AnishMahto/SPARK-56651-autocdc-python-api.
    
    Lead-authored-by: AnishMahto <[email protected]>
    Co-authored-by: andreas-neumann_data <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 python/pyspark/pipelines/__init__.py               |   2 +
 python/pyspark/pipelines/api.py                    | 194 ++++++++++++++++++++-
 python/pyspark/pipelines/flow.py                   |  34 +++-
 python/pyspark/pipelines/graph_element_registry.py |   6 +-
 .../spark_connect_graph_element_registry.py        |  40 ++++-
 .../tests/local_graph_element_registry.py          |  10 +-
 .../pipelines/tests/test_graph_element_registry.py | 159 ++++++++++++++++-
 7 files changed, 436 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/pipelines/__init__.py 
b/python/pyspark/pipelines/__init__.py
index d93320e96376..bd41c9ecd6b2 100644
--- a/python/pyspark/pipelines/__init__.py
+++ b/python/pyspark/pipelines/__init__.py
@@ -16,6 +16,7 @@
 #
 from pyspark.pipelines.api import (
     append_flow,
+    create_auto_cdc_flow,
     create_streaming_table,
     materialized_view,
     table,
@@ -25,6 +26,7 @@ from pyspark.pipelines.api import (
 
 __all__ = [
     "append_flow",
+    "create_auto_cdc_flow",
     "create_streaming_table",
     "materialized_view",
     "table",
diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py
index e6bae4f832d5..578b28ec3793 100644
--- a/python/pyspark/pipelines/api.py
+++ b/python/pyspark/pipelines/api.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Dict, List, Optional, Union, overload
+from typing import Callable, Dict, List, Literal, Optional, Union, overload
 
 from pyspark.errors import PySparkTypeError
 from pyspark.pipelines.graph_element_registry import 
get_active_graph_element_registry
 from pyspark.pipelines.type_error_utils import 
validate_optional_list_of_str_arg
-from pyspark.pipelines.flow import Flow, QueryFunction
+from pyspark.pipelines.flow import AutoCdcFlow, Flow, QueryFunction
 from pyspark.pipelines.source_code_location import (
     get_caller_source_code_location,
 )
@@ -29,6 +29,7 @@ from pyspark.pipelines.output import (
     TemporaryView,
     Sink,
 )
+from pyspark.sql import Column
 from pyspark.sql.types import StructType
 
 
@@ -525,3 +526,192 @@ def create_sink(
         comment=None,
     )
     get_active_graph_element_registry().register_output(sink)
+
+
+def create_auto_cdc_flow(
+    target: str,
+    source: str,
+    keys: Union[List[str], List[Column]],
+    sequence_by: Union[str, Column],
+    apply_as_deletes: Optional[Union[str, Column]] = None,
+    column_list: Optional[Union[List[str], List[Column]]] = None,
+    except_column_list: Optional[Union[List[str], List[Column]]] = None,
+    stored_as_scd_type: Optional[Literal[1, "1"]] = None,
+    name: Optional[str] = None,
+) -> None:
+    """
+    Create an Auto CDC flow into the target table from the Change Data Capture (CDC) source.
+    Target table must have already been created using create_streaming_table function. Only one
+    of column_list and except_column_list can be specified.
+
+    Example:
+        create_auto_cdc_flow(
+            target="target",
+            source="source",
+            keys=["key"],
+            sequence_by="sequence_expr",
+            column_list=["key", "value"],
+        )
+
+    Note that for keys, sequence_by, column_list, and except_column_list the arguments have to
+    be column identifiers without qualifiers, e.g. they cannot be col("sourceTable.keyId").
+
+    :param target: The name of the target table that receives the Auto CDC flow.
+    :param source: The name of the CDC source to stream from.
+    :param keys: The column or combination of columns that uniquely identify a row in the source \
+        data. This is used to identify which CDC events apply to specific records in the target \
+        table. These keys also identify records in the target table, e.g., if there exists a record \
+        for given keys and the CDC source has an UPSERT operation for the same keys, we will update \
+        the existing record. At least one key must be provided. This should be a list of column \
+        identifiers without qualifiers, expressed as either Python strings or PySpark Columns.
+    :param sequence_by: An expression that we use to order the source data. This can be expressed \
+        as either a SQL expression string or a PySpark Column.
+    :param apply_as_deletes: A boolean expression indicating whether an event represents a \
+        delete. This can be expressed as either a SQL expression string or a PySpark Column.
+    :param column_list: Columns that will be included in the output table. This should be a list \
+        of column identifiers without qualifiers, expressed as either Python strings or PySpark \
+        Columns. Only one of column_list and except_column_list can be specified.
+    :param except_column_list: Columns that will be excluded in the output table. This should be a \
+        list of column identifiers without qualifiers, expressed as either Python strings or \
+        PySpark Columns. Only one of column_list and except_column_list can be specified. When \
+        this is specified, all columns in the dataframe of the target table except those in this \
+        list will be in the output table.
+    :param stored_as_scd_type: The SCD type for the target table. Only 1 (or "1") is supported. \
+        When not specified the server default applies.
+    :param name: The name of the flow for this create_auto_cdc_flow command. When unspecified \
+        this will build a "default flow" with name equal to the target name.
+    :raises PySparkTypeError: If target, source, or name is not a str; if keys, column_list, or \
+        except_column_list is not a list of str / Column; if sequence_by or apply_as_deletes is \
+        not a str or Column; or if stored_as_scd_type is set to something other than 1 / "1".
+    :raises PySparkValueError: If keys is empty, or if both column_list and except_column_list \
+        are specified.
+    """
+    from pyspark.errors import PySparkValueError
+    from pyspark.sql.connect.functions.builtin import expr as _connect_expr
+
+    if type(target) is not str:
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "target",
+                "expected_type": "str",
+                "arg_type": type(target).__name__,
+            },
+        )
+    if type(source) is not str:
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "source",
+                "expected_type": "str",
+                "arg_type": type(source).__name__,
+            },
+        )
+    if name is not None and type(name) is not str:
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "name",
+                "expected_type": "str",
+                "arg_type": type(name).__name__,
+            },
+        )
+
+    if name is None:
+        # An unnamed flow is the target's "default flow" and takes the target's name.
+        name = target
+
+    keys = _normalize_column_list(arg_name="keys", column_list=keys)
+    if not keys:
+        # The documented contract requires at least one key; fail fast on the client rather
+        # than register a flow that has no keys.
+        raise PySparkValueError(
+            errorClass="CANNOT_BE_EMPTY",
+            messageParameters={"item": "key"},
+        )
+
+    column_list = _normalize_optional_column_list(arg_name="column_list", column_list=column_list)
+    except_column_list = _normalize_optional_column_list(
+        arg_name="except_column_list", column_list=except_column_list
+    )
+    if column_list is not None and except_column_list is not None:
+        # Both arguments select the output columns, so (as documented) they are mutually
+        # exclusive.
+        raise PySparkValueError(
+            errorClass="CANNOT_SET_TOGETHER",
+            messageParameters={"arg_list": "column_list and except_column_list"},
+        )
+
+    if isinstance(sequence_by, str):
+        sequence_by = _connect_expr(sequence_by)
+    elif not isinstance(sequence_by, Column):
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "sequence_by",
+                "expected_type": "str or Column",
+                "arg_type": type(sequence_by).__name__,
+            },
+        )
+
+    if isinstance(apply_as_deletes, str):
+        apply_as_deletes = _connect_expr(apply_as_deletes)
+    elif apply_as_deletes is not None and not isinstance(apply_as_deletes, Column):
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "apply_as_deletes",
+                "expected_type": "str or Column",
+                "arg_type": type(apply_as_deletes).__name__,
+            },
+        )
+
+    # Comparing the string form accepts both 1 and "1" while rejecting e.g. True or 1.0.
+    if stored_as_scd_type is not None and str(stored_as_scd_type) != "1":
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "stored_as_scd_type",
+                "expected_type": "Literal[1, '1']",
+                "arg_type": type(stored_as_scd_type).__name__,
+            },
+        )
+
+    source_code_location = get_caller_source_code_location(stacklevel=1)
+
+    flow = AutoCdcFlow(
+        name=name,
+        target=target,
+        source=source,
+        keys=keys,
+        sequence_by=sequence_by,
+        apply_as_deletes=apply_as_deletes,
+        column_list=column_list,
+        except_column_list=except_column_list,
+        stored_as_scd_type=stored_as_scd_type,
+        source_code_location=source_code_location,
+    )
+
+    get_active_graph_element_registry().register_auto_cdc_flow(flow)
+
+
+def _normalize_optional_column_list(
+    arg_name: str,
+    column_list: Optional[Union[List[str], List[Column]]],
+) -> Optional[List[Column]]:
+    """Same as _normalize_column_list, except that an absent (None) list is passed through."""
+    return (
+        None
+        if column_list is None
+        else _normalize_column_list(arg_name=arg_name, column_list=column_list)
+    )
+
+
+def _normalize_column_list(
+    arg_name: str,
+    column_list: Union[List[str], List[Column]],
+) -> List[Column]:
+    """
+    Convert a list of column names and/or Columns into a list of Columns, preserving order.
+
+    str entries become Columns via the Spark Connect ``col`` function; Column entries are kept
+    as-is.
+
+    :param arg_name: Name of the user-facing argument, reported in error messages.
+    :param column_list: The list to normalize.
+    :raises PySparkTypeError: If column_list is not a list, or on the first element that is
+        neither a str nor a Column.
+    """
+    from pyspark.sql.connect.functions.builtin import col as _connect_col
+
+    def _type_error(offending: object) -> PySparkTypeError:
+        return PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": arg_name,
+                "expected_type": "list[str] or list[Column]",
+                "arg_type": type(offending).__name__,
+            },
+        )
+
+    if not isinstance(column_list, list):
+        raise _type_error(column_list)
+
+    def _to_column(item: object) -> Column:
+        if isinstance(item, str):
+            return _connect_col(item)
+        if isinstance(item, Column):
+            return item
+        raise _type_error(item)
+
+    return [_to_column(item) for item in column_list]
diff --git a/python/pyspark/pipelines/flow.py b/python/pyspark/pipelines/flow.py
index 7c499c0b3622..02e971aedd87 100644
--- a/python/pyspark/pipelines/flow.py
+++ b/python/pyspark/pipelines/flow.py
@@ -15,9 +15,10 @@
 # limitations under the License.
 #
 from dataclasses import dataclass
-from typing import Callable, Dict
+from typing import Callable, Dict, List, Literal, Optional
 
 from pyspark.sql import DataFrame
+from pyspark.sql import Column
 from pyspark.pipelines.source_code_location import SourceCodeLocation
 
 QueryFunction = Callable[[], DataFrame]
@@ -41,3 +42,34 @@ class Flow:
     spark_conf: Dict[str, str]
     source_code_location: SourceCodeLocation
     func: QueryFunction
+
+
+@dataclass(frozen=True)
+class AutoCdcFlow:
+    """Definition of an Auto CDC flow in a pipeline dataflow graph.
+
+    An Auto CDC flow applies Change Data Capture (CDC) events from a source to a target
+    streaming table. Unlike ``Flow``, it has no query function: the source is referenced by name,
+    and the fields below describe how its events are applied. Instances are immutable
+    (``frozen=True``).
+
+    :param name: Name of the flow. The field is typed Optional but has no default here;
+        ``create_auto_cdc_flow`` substitutes the target name when the caller gives none.
+    :param target: The name of the target streaming table.
+    :param source: The name of the CDC source to stream from.
+    :param keys: Column(s) that uniquely identify a row in source and target data.
+    :param sequence_by: Expression used to order the source data.
+    :param apply_as_deletes: Optional boolean expression marking events that represent deletes.
+    :param column_list: Optional columns to include in the output table.
+    :param except_column_list: Optional columns to exclude from the output table.
+    :param stored_as_scd_type: Optional SCD type for the target table. Only 1 (or "1") is
+        supported; the value is stored exactly as the caller passed it.
+    :param source_code_location: The location of the source code that created this flow.
+    """
+
+    name: Optional[str]
+    target: str
+    source: str
+    keys: List[Column]
+    sequence_by: Column
+    apply_as_deletes: Optional[Column]
+    column_list: Optional[List[Column]]
+    except_column_list: Optional[List[Column]]
+    stored_as_scd_type: Optional[Literal[1, "1"]]
+    source_code_location: SourceCodeLocation
diff --git a/python/pyspark/pipelines/graph_element_registry.py 
b/python/pyspark/pipelines/graph_element_registry.py
index 8e311fc2ca98..4eddabaabda0 100644
--- a/python/pyspark/pipelines/graph_element_registry.py
+++ b/python/pyspark/pipelines/graph_element_registry.py
@@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 
 from pyspark.pipelines.output import Output
-from pyspark.pipelines.flow import Flow
+from pyspark.pipelines.flow import AutoCdcFlow, Flow
 from contextlib import contextmanager
 from contextvars import ContextVar
 from typing import Generator, Optional
@@ -42,6 +42,10 @@ class GraphElementRegistry(ABC):
     def register_flow(self, flow: Flow) -> None:
         """Add the given flow to the registry."""
 
+    @abstractmethod
+    def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None:
+        """Add the given Auto CDC flow to the registry.
+
+        :param flow: The flow to register. ``create_auto_cdc_flow`` builds it after validating
+            its arguments, with all column arguments already normalized to ``Column`` objects.
+        """
+
     @abstractmethod
     def register_sql(self, sql_text: str, file_path: Path) -> None:
         """Register a string containing SQL statements the dataflow graph.
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py 
b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index ab8831790830..2eef264990a3 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -17,7 +17,7 @@
 from pathlib import Path
 
 from pyspark.errors import PySparkTypeError
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, Column
 from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 from pyspark.pipelines.output import (
     Output,
@@ -27,12 +27,12 @@ from pyspark.pipelines.output import (
     StreamingTable,
     TemporaryView,
 )
-from pyspark.pipelines.flow import Flow
+from pyspark.pipelines.flow import AutoCdcFlow, Flow
 from pyspark.pipelines.graph_element_registry import GraphElementRegistry
 from pyspark.pipelines.source_code_location import SourceCodeLocation
 from pyspark.sql.connect.types import pyspark_types_to_proto_types
 from pyspark.sql.types import StructType
-from typing import Any, cast
+from typing import Any, List, Optional, cast
 import pyspark.sql.connect.proto as pb2
 from pyspark.pipelines.add_pipeline_analysis_context import 
add_pipeline_analysis_context
 
@@ -133,6 +133,40 @@ class 
SparkConnectGraphElementRegistry(GraphElementRegistry):
         command.pipeline_command.define_flow.CopyFrom(inner_command)
         self._client.execute_command(command)
 
+    def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None:
+        """Send a DefineFlow pipeline command describing the given Auto CDC flow.
+
+        Each Column in the flow is serialized with the Spark Connect client. apply_as_deletes and
+        stored_as_scd_type are only set when the flow provides them; an absent column_list or
+        except_column_list is sent as an empty list.
+        """
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        def expression_proto(column: Column) -> Any:
+            # Assumes a Spark Connect Column (the pipelines API builds str arguments with the
+            # Connect col/expr functions); the cast itself is unchecked.
+            return cast(ConnectColumn, column).to_plan(self._client)
+
+        def expression_protos(columns: Optional[List[Column]]) -> list:
+            if columns is None:
+                return []
+            return [expression_proto(column) for column in columns]
+
+        define_flow = pb2.PipelineCommand.DefineFlow
+        auto_cdc_details = define_flow.AutoCdcFlowDetails(
+            source=flow.source,
+            keys=expression_protos(flow.keys),
+            sequence_by=expression_proto(flow.sequence_by),
+            column_list=expression_protos(flow.column_list),
+            except_column_list=expression_protos(flow.except_column_list),
+        )
+        if flow.stored_as_scd_type is not None:
+            # create_auto_cdc_flow only accepts SCD type 1, so any set value maps to SCD_TYPE_1.
+            auto_cdc_details.stored_as_scd_type = define_flow.SCDType.SCD_TYPE_1
+        if flow.apply_as_deletes is not None:
+            auto_cdc_details.apply_as_deletes.CopyFrom(expression_proto(flow.apply_as_deletes))
+
+        flow_definition = define_flow(
+            dataflow_graph_id=self._dataflow_graph_id,
+            flow_name=flow.name,
+            target_dataset_name=flow.target,
+            auto_cdc_flow_details=auto_cdc_details,
+            sql_conf={},
+            source_code_location=source_code_location_to_proto(flow.source_code_location),
+        )
+
+        command = pb2.Command()
+        command.pipeline_command.define_flow.CopyFrom(flow_definition)
+        self._client.execute_command(command)
+
     def register_sql(self, sql_text: str, file_path: Path) -> None:
         inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
             dataflow_graph_id=self._dataflow_graph_id,
diff --git a/python/pyspark/pipelines/tests/local_graph_element_registry.py 
b/python/pyspark/pipelines/tests/local_graph_element_registry.py
index 0e22641930b9..3b9ea15a1ed6 100644
--- a/python/pyspark/pipelines/tests/local_graph_element_registry.py
+++ b/python/pyspark/pipelines/tests/local_graph_element_registry.py
@@ -20,7 +20,7 @@ from pathlib import Path
 from typing import List, Sequence
 
 from pyspark.pipelines.output import Output
-from pyspark.pipelines.flow import Flow
+from pyspark.pipelines.flow import AutoCdcFlow, Flow
 from pyspark.pipelines.graph_element_registry import GraphElementRegistry
 
 
@@ -34,6 +34,7 @@ class LocalGraphElementRegistry(GraphElementRegistry):
     def __init__(self) -> None:
         self._outputs: List[Output] = []
         self._flows: List[Flow] = []
+        # Auto CDC flows are kept apart from regular flows: AutoCdcFlow is a separate type
+        # without a query function.
+        self._auto_cdc_flows: List[AutoCdcFlow] = []
         self._sql_files: List[SqlFile] = []
 
     def register_output(self, output: Output) -> None:
@@ -42,6 +43,9 @@ class LocalGraphElementRegistry(GraphElementRegistry):
     def register_flow(self, flow: Flow) -> None:
         self._flows.append(flow)
 
+    def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None:
+        """Record the flow in memory, preserving registration order."""
+        self._auto_cdc_flows.extend((flow,))
+
     def register_sql(self, sql_text: str, file_path: Path) -> None:
         self._sql_files.append(SqlFile(sql_text, file_path))
 
@@ -53,6 +57,10 @@ class LocalGraphElementRegistry(GraphElementRegistry):
     def flows(self) -> Sequence[Flow]:
         return self._flows
 
+    @property
+    def auto_cdc_flows(self) -> Sequence[AutoCdcFlow]:
+        """Auto CDC flows registered so far, in registration order.
+
+        This is the registry's internal list, not a copy.
+        """
+        return self._auto_cdc_flows
+
     @property
     def sql_files(self) -> Sequence[SqlFile]:
         return self._sql_files
diff --git a/python/pyspark/pipelines/tests/test_graph_element_registry.py 
b/python/pyspark/pipelines/tests/test_graph_element_registry.py
index 1e6fcf224a0a..fd8ed439b130 100644
--- a/python/pyspark/pipelines/tests/test_graph_element_registry.py
+++ b/python/pyspark/pipelines/tests/test_graph_element_registry.py
@@ -17,11 +17,14 @@
 
 import unittest
 
-from pyspark.errors import PySparkException
+from pyspark.errors import PySparkException, PySparkTypeError
 from pyspark.pipelines.graph_element_registry import 
graph_element_registration_context
 from pyspark import pipelines as dp
+from pyspark.pipelines.flow import AutoCdcFlow
 from pyspark.pipelines.output import Sink
 from pyspark.pipelines.tests.local_graph_element_registry import 
LocalGraphElementRegistry
+from pyspark.sql import Column
+from pyspark.sql.connect.functions.builtin import col, expr
 from typing import cast
 
 
@@ -97,6 +100,147 @@ class GraphElementRegistryTest(unittest.TestCase):
         self.assertEqual(sink_obj.options["key1"], "value1")
         assert 
sink_obj.source_code_location.filename.endswith("test_graph_element_registry.py")
 
+    def test_create_auto_cdc_flow(self):
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("target")
+            dp.create_auto_cdc_flow(
+                target="target",
+                source="source",
+                keys=[col("key")],
+                sequence_by=expr("seq"),
+            )
+
+        self.assertEqual(len(registry.outputs), 1)
+        self.assertEqual(len(registry.auto_cdc_flows), 1)
+
+        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
+        self.assertEqual(flow.target, "target")
+        self.assertEqual(flow.source, "source")
+        self.assertEqual(len(flow.keys), 1)
+        self.assertIsInstance(flow.sequence_by, Column)
+
+        # When name is not specified, it inherits the target's name at construction time.
+        self.assertEqual(flow.name, "target")
+        # Optional arguments that were not passed stay unset.
+        self.assertIsNone(flow.stored_as_scd_type)
+        self.assertIsNone(flow.apply_as_deletes)
+        self.assertIsNone(flow.column_list)
+        self.assertIsNone(flow.except_column_list)
+        assert flow.source_code_location.filename.endswith("test_graph_element_registry.py")
+
+    def test_create_auto_cdc_flow_with_all_args(self):
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("tgt")
+            dp.create_auto_cdc_flow(
+                target="tgt",
+                source="src",
+                keys=[col("id")],
+                sequence_by=expr("ts"),
+                apply_as_deletes=expr("op = 'DELETE'"),
+                column_list=[col("id"), col("val")],
+                stored_as_scd_type=1,
+                name="my_flow",
+            )
+
+        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
+        # Every supplied argument must be carried through to the registered flow.
+        self.assertEqual(flow.name, "my_flow")
+        self.assertEqual(flow.target, "tgt")
+        self.assertEqual(flow.source, "src")
+        self.assertEqual(len(flow.keys), 1)
+        self.assertIsNotNone(flow.apply_as_deletes)
+        assert flow.column_list is not None
+        self.assertEqual(len(flow.column_list), 2)
+        self.assertIsNone(flow.except_column_list)
+        self.assertEqual(flow.stored_as_scd_type, 1)
+
+    def test_create_auto_cdc_flow_with_string_args(self):
+        # Plain strings for keys / sequence_by / apply_as_deletes / column_list must come out as
+        # PySpark Columns, exactly as if col(...) / expr(...) had been passed instead.
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("tgt")
+            dp.create_auto_cdc_flow(
+                target="tgt",
+                source="src",
+                keys=["id"],
+                sequence_by="ts",
+                apply_as_deletes="op = 'DELETE'",
+                column_list=["id", "val"],
+            )
+
+        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
+        assert flow.column_list is not None
+        normalized = [*flow.keys, flow.sequence_by, flow.apply_as_deletes, *flow.column_list]
+        for value in normalized:
+            self.assertIsInstance(value, Column)
+
+    def test_create_auto_cdc_flow_stored_as_scd_type_string(self):
+        # The string form "1" is accepted and stored verbatim rather than coerced to an int.
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("t")
+            dp.create_auto_cdc_flow(
+                target="t",
+                source="s",
+                keys=[col("k")],
+                sequence_by=expr("seq"),
+                stored_as_scd_type="1",
+            )
+
+        registered = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
+        self.assertEqual("1", registered.stored_as_scd_type)
+
+    def test_create_auto_cdc_flow_invalid_scd_type(self):
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("t")
+            with self.assertRaises(PySparkTypeError) as ctx:
+                dp.create_auto_cdc_flow(
+                    target="t",
+                    source="s",
+                    keys=[col("k")],
+                    sequence_by=expr("seq"),
+                    stored_as_scd_type=2,  # type: ignore[arg-type]
+                )
+            self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")
+            # Pin which argument was rejected, not just the error class.
+            self.assertEqual(
+                ctx.exception.getMessageParameters(),
+                {
+                    "arg_name": "stored_as_scd_type",
+                    "expected_type": "Literal[1, '1']",
+                    "arg_type": "int",
+                },
+            )
+        # A rejected definition must not reach the registry.
+        self.assertEqual(len(registry.auto_cdc_flows), 0)
+
+    def test_create_auto_cdc_flow_with_except_column_list(self):
+        # String entries in except_column_list are normalized to Columns; column_list stays unset.
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("tgt")
+            dp.create_auto_cdc_flow(
+                target="tgt",
+                source="src",
+                keys=[col("id")],
+                sequence_by=expr("ts"),
+                except_column_list=["op", "ts"],
+            )
+
+        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
+        self.assertIsNone(flow.column_list)
+        excluded = flow.except_column_list
+        assert excluded is not None
+        self.assertEqual(len(excluded), 2)
+        self.assertTrue(all(isinstance(c, Column) for c in excluded))
+
+    def test_create_auto_cdc_flow_rejects_non_str_target(self):
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("tgt")
+            with self.assertRaises(PySparkTypeError) as ctx:
+                dp.create_auto_cdc_flow(
+                    target=123,  # type: ignore[arg-type]
+                    source="src",
+                    keys=[col("id")],
+                    sequence_by=expr("ts"),
+                )
+            self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")
+            # Pin which argument was rejected, not just the error class.
+            self.assertEqual(
+                ctx.exception.getMessageParameters(),
+                {"arg_name": "target", "expected_type": "str", "arg_type": "int"},
+            )
+
+    def test_create_auto_cdc_flow_rejects_invalid_key_element(self):
+        registry = LocalGraphElementRegistry()
+        with graph_element_registration_context(registry):
+            dp.create_streaming_table("tgt")
+            with self.assertRaises(PySparkTypeError) as ctx:
+                dp.create_auto_cdc_flow(
+                    target="tgt",
+                    source="src",
+                    keys=[123],  # type: ignore[list-item]
+                    sequence_by=expr("ts"),
+                )
+            self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")
+            # The offending element's type is reported against the "keys" argument.
+            self.assertEqual(
+                ctx.exception.getMessageParameters(),
+                {
+                    "arg_name": "keys",
+                    "expected_type": "list[str] or list[Column]",
+                    "arg_type": "int",
+                },
+            )
+
     def test_definition_without_graph_element_registry(self):
         for decorator in [dp.table, dp.temporary_view, dp.materialized_view]:
             with self.assertRaises(PySparkException) as context:
@@ -129,6 +273,19 @@ class GraphElementRegistryTest(unittest.TestCase):
             "GRAPH_ELEMENT_DEFINED_OUTSIDE_OF_DECLARATIVE_PIPELINE",
         )
 
+        with self.assertRaises(PySparkException) as context:
+            dp.create_auto_cdc_flow(
+                target="t",
+                source="s",
+                keys=["k"],
+                sequence_by="seq",
+            )
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            "GRAPH_ELEMENT_DEFINED_OUTSIDE_OF_DECLARATIVE_PIPELINE",
+        )
+
 
 if __name__ == "__main__":
     from pyspark.testing import main


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to