This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a927a14a59e2 [SPARK-52463][SDP] Add support for cluster_by in Python Pipelines APIs
a927a14a59e2 is described below

commit a927a14a59e2d07ae10e73f3e28f3ca1b1208929
Author: Sandy Ryza <[email protected]>
AuthorDate: Thu Nov 6 10:45:08 2025 -0800

    [SPARK-52463][SDP] Add support for cluster_by in Python Pipelines APIs
    
    ### What changes were proposed in this pull request?
    
    In the `table` and `materialized_view` decorators, accept a `cluster_by` argument that determines the clustering columns.
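
    For example, a minimal usage sketch (the view name `clustered_mv` and the column `id_mod`
    are illustrative only, not taken from this patch):

        from pyspark import pipelines as dp
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import col

        spark = SparkSession.active()

        @dp.materialized_view(cluster_by=["id_mod"])
        def clustered_mv():
            # The materialized view is clustered by the "id_mod" column.
            return spark.range(10).withColumn("id_mod", col("id") % 2)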
    
    ### Why are the changes needed?
    
    Parity with the `clusterBy` argument accepted by `DataStreamReader` and `DataFrameWriter`.
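
    For reference, a minimal sketch of the analogous writer-side API this mirrors (the table
    name `t` is chosen for illustration):

        spark.range(10).write.clusterBy("id").format("parquet").saveAsTable("t")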
    
    ### Does this PR introduce _any_ user-facing change?
    
    Adds a new parameter to public APIs.
    
    ### How was this patch tested?
    
    Unit tests and integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #52831 from sryza/cluster-by.
    
    Authored-by: Sandy Ryza <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/pipelines/api.py                    |  18 +-
 python/pyspark/pipelines/output.py                 |   2 +
 .../spark_connect_graph_element_registry.py        |   1 +
 python/pyspark/sql/connect/proto/pipelines_pb2.py  |  88 ++++----
 python/pyspark/sql/connect/proto/pipelines_pb2.pyi |   9 +
 .../main/protobuf/spark/connect/pipelines.proto    |   3 +
 .../sql/connect/pipelines/PipelinesHandler.scala   |   2 +
 .../connect/pipelines/PythonPipelineSuite.scala    |  37 ++++
 .../connect/pipelines/TestPipelineDefinition.scala |   2 +
 .../spark/sql/pipelines/graph/DatasetManager.scala |  27 ++-
 .../spark/sql/pipelines/graph/FlowExecution.scala  |  16 +-
 .../graph/SqlGraphRegistrationContext.scala        |   3 +
 .../spark/sql/pipelines/graph/elements.scala       |   2 +
 .../pipelines/graph/MaterializeTablesSuite.scala   | 246 ++++++++++++++++++++-
 .../apache/spark/sql/pipelines/utils/APITest.scala |  60 +++++
 .../utils/TestGraphRegistrationContext.scala       |   6 +
 16 files changed, 464 insertions(+), 58 deletions(-)

diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py
index b68cc30b43a7..c01e8524eee2 100644
--- a/python/pyspark/pipelines/api.py
+++ b/python/pyspark/pipelines/api.py
@@ -76,6 +76,7 @@ def _validate_stored_dataset_args(
     name: Optional[str],
     table_properties: Optional[Dict[str, str]],
     partition_cols: Optional[List[str]],
+    cluster_by: Optional[List[str]],
 ) -> None:
     if name is not None and type(name) is not str:
         raise PySparkTypeError(
@@ -91,6 +92,7 @@ def _validate_stored_dataset_args(
             },
         )
     validate_optional_list_of_str_arg(arg_name="partition_cols", 
arg_value=partition_cols)
+    validate_optional_list_of_str_arg(arg_name="cluster_by", 
arg_value=cluster_by)
 
 
 @overload
@@ -107,6 +109,7 @@ def table(
     spark_conf: Optional[Dict[str, str]] = None,
     table_properties: Optional[Dict[str, str]] = None,
     partition_cols: Optional[List[str]] = None,
+    cluster_by: Optional[List[str]] = None,
     schema: Optional[Union[StructType, str]] = None,
 ) -> Callable[[QueryFunction], None]:
     ...
@@ -120,6 +123,7 @@ def table(
     spark_conf: Optional[Dict[str, str]] = None,
     table_properties: Optional[Dict[str, str]] = None,
     partition_cols: Optional[List[str]] = None,
+    cluster_by: Optional[List[str]] = None,
     schema: Optional[Union[StructType, str]] = None,
     format: Optional[str] = None,
 ) -> Union[Callable[[QueryFunction], None], None]:
@@ -142,11 +146,12 @@ def table(
     :param table_properties: A dict where the keys are the property names and the values are the \
         property values. These properties will be set on the table.
     :param partition_cols: A list containing the column names of the partition columns.
+    :param cluster_by: A list containing the column names of the cluster columns.
     :param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
         Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
     :param format: The format of the table, e.g. "parquet".
     """
-    _validate_stored_dataset_args(name, table_properties, partition_cols)
+    _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
 
     source_code_location = get_caller_source_code_location(stacklevel=1)
 
@@ -163,6 +168,7 @@ def table(
                 name=resolved_name,
                 table_properties=table_properties or {},
                 partition_cols=partition_cols,
+                cluster_by=cluster_by,
                 schema=schema,
                 source_code_location=source_code_location,
                 format=format,
@@ -209,6 +215,7 @@ def materialized_view(
     spark_conf: Optional[Dict[str, str]] = None,
     table_properties: Optional[Dict[str, str]] = None,
     partition_cols: Optional[List[str]] = None,
+    cluster_by: Optional[List[str]] = None,
     schema: Optional[Union[StructType, str]] = None,
 ) -> Callable[[QueryFunction], None]:
     ...
@@ -222,6 +229,7 @@ def materialized_view(
     spark_conf: Optional[Dict[str, str]] = None,
     table_properties: Optional[Dict[str, str]] = None,
     partition_cols: Optional[List[str]] = None,
+    cluster_by: Optional[List[str]] = None,
     schema: Optional[Union[StructType, str]] = None,
     format: Optional[str] = None,
 ) -> Union[Callable[[QueryFunction], None], None]:
@@ -244,11 +252,12 @@ def materialized_view(
     :param table_properties: A dict where the keys are the property names and the values are the \
         property values. These properties will be set on the table.
     :param partition_cols: A list containing the column names of the partition columns.
+    :param cluster_by: A list containing the column names of the cluster columns.
     :param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
         Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
     :param format: The format of the table, e.g. "parquet".
     """
-    _validate_stored_dataset_args(name, table_properties, partition_cols)
+    _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
 
     source_code_location = get_caller_source_code_location(stacklevel=1)
 
@@ -265,6 +274,7 @@ def materialized_view(
                 name=resolved_name,
                 table_properties=table_properties or {},
                 partition_cols=partition_cols,
+                cluster_by=cluster_by,
                 schema=schema,
                 source_code_location=source_code_location,
                 format=format,
@@ -403,6 +413,7 @@ def create_streaming_table(
     comment: Optional[str] = None,
     table_properties: Optional[Dict[str, str]] = None,
     partition_cols: Optional[List[str]] = None,
+    cluster_by: Optional[List[str]] = None,
     schema: Optional[Union[StructType, str]] = None,
     format: Optional[str] = None,
 ) -> None:
@@ -417,6 +428,7 @@ def create_streaming_table(
     :param table_properties: A dict where the keys are the property names and the values are the \
         property values. These properties will be set on the table.
     :param partition_cols: A list containing the column names of the partition columns.
+    :param cluster_by: A list containing the column names of the cluster columns.
     :param schema Explicit Spark SQL schema to materialize this table with. Supports either a \
         Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
     :param format: The format of the table, e.g. "parquet".
@@ -435,6 +447,7 @@ def create_streaming_table(
             },
         )
     validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
+    validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)
 
     source_code_location = get_caller_source_code_location(stacklevel=1)
 
@@ -444,6 +457,7 @@ def create_streaming_table(
         source_code_location=source_code_location,
         table_properties=table_properties or {},
         partition_cols=partition_cols,
+        cluster_by=cluster_by,
         schema=schema,
         format=format,
     )
diff --git a/python/pyspark/pipelines/output.py b/python/pyspark/pipelines/output.py
index 84e950f16174..92058e68721f 100644
--- a/python/pyspark/pipelines/output.py
+++ b/python/pyspark/pipelines/output.py
@@ -45,6 +45,7 @@ class Table(Output):
     :param table_properties: A dict where the keys are the property names and the values are the
         property values. These properties will be set on the table.
     :param partition_cols: A list containing the column names of the partition columns.
+    :param cluster_by: A list containing the column names of the cluster columns.
     :param schema Explicit Spark SQL schema to materialize this table with. Supports either a
         Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
     :param format: The format of the table, e.g. "parquet".
@@ -52,6 +53,7 @@ class Table(Output):
 
     table_properties: Mapping[str, str]
     partition_cols: Optional[Sequence[str]]
+    cluster_by: Optional[Sequence[str]]
     schema: Optional[Union[StructType, str]]
     format: Optional[str]
 
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index 5c5ef9fc3040..e8a8561c3e74 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -63,6 +63,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
             table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
                 table_properties=output.table_properties,
                 partition_cols=output.partition_cols,
+                clustering_columns=output.cluster_by,
                 format=output.format,
                 # Even though schema_string is not required, the generated Python code seems to
                 # erroneously think it is required.
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index d7321fa7cf0c..139de83dc1aa 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -42,7 +42,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9c"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutp [...]
+    b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcb"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutp [...]
 )
 
 _globals = globals()
@@ -69,10 +69,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
     ]._serialized_options = b"8\001"
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = 
b"8\001"
-    _globals["_OUTPUTTYPE"]._serialized_start = 6058
-    _globals["_OUTPUTTYPE"]._serialized_end = 6163
+    _globals["_OUTPUTTYPE"]._serialized_start = 6105
+    _globals["_OUTPUTTYPE"]._serialized_end = 6210
     _globals["_PIPELINECOMMAND"]._serialized_start = 195
-    _globals["_PIPELINECOMMAND"]._serialized_end = 4575
+    _globals["_PIPELINECOMMAND"]._serialized_end = 4622
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1129
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1437
     
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start 
= 1338
@@ -80,51 +80,51 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1439
     _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1529
     _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_start = 1532
-    _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2783
+    _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2830
     _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_start = 
2068
-    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end = 
2469
+    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end = 
2516
     _globals[
         "_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
-    ]._serialized_start = 2382
+    ]._serialized_start = 2429
     _globals[
         "_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
-    ]._serialized_end = 2448
-    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start = 
2472
-    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end = 
2681
-    
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start
 = 2612
-    
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end
 = 2670
-    _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2786
-    _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3647
+    ]._serialized_end = 2495
+    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start = 
2519
+    _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end = 
2728
+    
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start
 = 2659
+    
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end
 = 2717
+    _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2833
+    _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3694
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 
1338
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1396
-    
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start
 = 3380
-    
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end
 = 3477
-    _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3479
-    _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3537
-    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3650
-    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3972
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 
3975
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4174
-    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
 = 4177
-    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
 = 4335
-    
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 
4338
-    _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end 
= 4559
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4578
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5330
-    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start 
= 4947
-    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 
5045
-    _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start = 
5048
-    _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end = 
5181
-    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 
5184
-    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5315
-    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5332
-    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5405
-    _globals["_PIPELINEEVENT"]._serialized_start = 5407
-    _globals["_PIPELINEEVENT"]._serialized_end = 5523
-    _globals["_SOURCECODELOCATION"]._serialized_start = 5526
-    _globals["_SOURCECODELOCATION"]._serialized_end = 5767
-    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5769
-    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5838
-    _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5841
-    _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6056
+    
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start
 = 3427
+    
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end
 = 3524
+    _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3526
+    _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3584
+    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3697
+    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 4019
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 
4022
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4221
+    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
 = 4224
+    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
 = 4382
+    
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 
4385
+    _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end 
= 4606
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4625
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5377
+    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start 
= 4994
+    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 
5092
+    _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start = 
5095
+    _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end = 
5228
+    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 
5231
+    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5362
+    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5379
+    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5452
+    _globals["_PIPELINEEVENT"]._serialized_start = 5454
+    _globals["_PIPELINEEVENT"]._serialized_end = 5570
+    _globals["_SOURCECODELOCATION"]._serialized_start = 5573
+    _globals["_SOURCECODELOCATION"]._serialized_end = 5814
+    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5816
+    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5885
+    _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5888
+    _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6103
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index b9170e763ed9..60d131037c99 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -240,6 +240,7 @@ class PipelineCommand(google.protobuf.message.Message):
             FORMAT_FIELD_NUMBER: builtins.int
             SCHEMA_DATA_TYPE_FIELD_NUMBER: builtins.int
             SCHEMA_STRING_FIELD_NUMBER: builtins.int
+            CLUSTERING_COLUMNS_FIELD_NUMBER: builtins.int
             @property
             def table_properties(
                 self,
@@ -255,6 +256,11 @@ class PipelineCommand(google.protobuf.message.Message):
             @property
             def schema_data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
             schema_string: builtins.str
+            @property
+            def clustering_columns(
+                self,
+            ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+                """Optional cluster columns for the table."""
             def __init__(
                 self,
                 *,
@@ -263,6 +269,7 @@ class PipelineCommand(google.protobuf.message.Message):
                 format: builtins.str | None = ...,
                 schema_data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
                 schema_string: builtins.str = ...,
+                clustering_columns: collections.abc.Iterable[builtins.str] | None = ...,
             ) -> None: ...
             def HasField(
                 self,
@@ -284,6 +291,8 @@ class PipelineCommand(google.protobuf.message.Message):
                 field_name: typing_extensions.Literal[
                     "_format",
                     b"_format",
+                    "clustering_columns",
+                    b"clustering_columns",
                     "format",
                     b"format",
                     "partition_cols",
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index c6a5e571f979..0fa36f8a1514 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -104,6 +104,9 @@ message PipelineCommand {
         spark.connect.DataType schema_data_type = 4;
         string schema_string = 5;
       }
+
+      // Optional cluster columns for the table.
+      repeated string clustering_columns = 6;
     }
 
     // Metadata that's only applicable to sinks.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7e69e546893e..0929b07be523 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -203,6 +203,8 @@ private[connect] object PipelinesHandler extends Logging {
             },
             partitionCols = Option(tableDetails.getPartitionColsList.asScala.toSeq)
               .filter(_.nonEmpty),
+            clusterCols = Option(tableDetails.getClusteringColumnsList.asScala.toSeq)
+              .filter(_.nonEmpty),
             properties = tableDetails.getTablePropertiesMap.asScala.toMap,
             origin = QueryOrigin(
               filePath = Option.when(output.getSourceCodeLocation.hasFileName)(
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 79c34ac46b9f..1a72d112aa2e 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -865,4 +865,41 @@ class PythonPipelineSuite
 
     (exitCode, output.toSeq)
   }
+
+  test("empty cluster_by list should work and create table with no 
clustering") {
+    withTable("mv", "st") {
+      val graph = buildGraph("""
+            |from pyspark.sql.functions import col
+            |
+            |@dp.materialized_view(cluster_by = [])
+            |def mv():
+            |  return spark.range(5).withColumn("id_mod", col("id") % 2)
+            |
+            |@dp.table(cluster_by = [])
+            |def st():
+            |  return spark.readStream.table("mv")
+            |""".stripMargin)
+      val updateContext =
+        new PipelineUpdateContextImpl(graph, eventCallback = _ => (), storageRoot = storageRoot)
+      updateContext.pipelineExecution.runPipeline()
+      updateContext.pipelineExecution.awaitCompletion()
+
+      // Check tables are created with no clustering transforms
+      val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+
+      val mvIdentifier = Identifier.of(Array("default"), "mv")
+      val mvTable = catalog.loadTable(mvIdentifier)
+      val mvTransforms = mvTable.partitioning()
+      assert(
+        mvTransforms.isEmpty,
+        s"MaterializedView should have no transforms, but got: 
${mvTransforms.mkString(", ")}")
+
+      val stIdentifier = Identifier.of(Array("default"), "st")
+      val stTable = catalog.loadTable(stIdentifier)
+      val stTransforms = stTable.partitioning()
+      assert(
+        stTransforms.isEmpty,
+        s"Table should have no transforms, but got: ${stTransforms.mkString(", 
")}")
+    }
+  }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
index dfb766b1df77..f3b63f791421 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
@@ -41,10 +41,12 @@ class TestPipelineDefinition(graphId: String) {
       // TODO: Add support for specifiedSchema
       // specifiedSchema: Option[StructType] = None,
       partitionCols: Option[Seq[String]] = None,
+      clusterCols: Option[Seq[String]] = None,
       properties: Map[String, String] = Map.empty): Unit = {
     val tableDetails = sc.PipelineCommand.DefineOutput.TableDetails
       .newBuilder()
       .addAllPartitionCols(partitionCols.getOrElse(Seq()).asJava)
+      .addAllClusteringColumns(clusterCols.getOrElse(Seq()).asJava)
       .putAllTableProperties(properties.asJava)
       .build()
 
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
index cb142988ce51..e5c87fa542ad 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.{
   TableInfo
 }
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.v2ColumnsToStructType
-import org.apache.spark.sql.connector.expressions.Expressions
+import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions}
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.pipelines.graph.QueryOrigin.ExceptionHelpers
 import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils.diffSchemas
@@ -266,6 +266,19 @@ object DatasetManager extends Logging {
     )
     val mergedProperties = resolveTableProperties(table, identifier)
     val partitioning = table.partitionCols.toSeq.flatten.map(Expressions.identity)
+    val clustering = table.clusterCols.map(cols =>
+      ClusterByTransform(cols.map(col => Expressions.column(col)))
+    ).toSeq
+
+    // Validate that partition and cluster columns don't coexist
+    if (partitioning.nonEmpty && clustering.nonEmpty) {
+      throw new AnalysisException(
+        errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED",
+        messageParameters = Map.empty
+      )
+    }
+
+    val allTransforms = partitioning ++ clustering
 
     val existingTableOpt = if (catalog.tableExists(identifier)) {
       Some(catalog.loadTable(identifier))
@@ -273,15 +286,15 @@ object DatasetManager extends Logging {
       None
     }
 
-    // Error if partitioning doesn't match
+    // Error if partitioning/clustering doesn't match
     if (existingTableOpt.isDefined) {
-      val existingPartitioning = existingTableOpt.get.partitioning().toSeq
-      if (existingPartitioning != partitioning) {
+      val existingTransforms = existingTableOpt.get.partitioning().toSeq
+      if (existingTransforms != allTransforms) {
         throw new AnalysisException(
           errorClass = "CANNOT_UPDATE_PARTITION_COLUMNS",
           messageParameters = Map(
-            "existingPartitionColumns" -> existingPartitioning.mkString(", "),
-            "requestedPartitionColumns" -> partitioning.mkString(", ")
+            "existingPartitionColumns" -> existingTransforms.mkString(", "),
+            "requestedPartitionColumns" -> allTransforms.mkString(", ")
           )
         )
       }
@@ -314,7 +327,7 @@ object DatasetManager extends Logging {
         new TableInfo.Builder()
           .withProperties(mergedProperties.asJava)
           .withColumns(CatalogV2Util.structTypeToV2Columns(outputSchema))
-          .withPartitions(partitioning.toArray)
+          .withPartitions(allTransforms.toArray)
           .build()
       )
     }
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
index 2c9029fdd34d..647df79bb940 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
@@ -264,12 +264,20 @@ class BatchTableWrite(
         if (destination.format.isDefined) {
           dataFrameWriter.format(destination.format.get)
         }
+
+        // In "append" mode with saveAsTable, partition/cluster columns must 
be specified in query
+        // because the format and options of the existing table is used, and 
the table could
+        // have been created with partition columns.
+        if (destination.clusterCols.isDefined) {
+          val clusterCols = destination.clusterCols.get
+          dataFrameWriter.clusterBy(clusterCols.head, clusterCols.tail: _*)
+        }
+        if (destination.partitionCols.isDefined) {
+          dataFrameWriter.partitionBy(destination.partitionCols.get: _*)
+        }
+
         dataFrameWriter
           .mode("append")
-          // In "append" mode with saveAsTable, partition columns must be 
specified in query
-          // because the format and options of the existing table is used, and 
the table could
-          // have been created with partition columns.
-          .partitionBy(destination.partitionCols.getOrElse(Seq.empty): _*)
           .saveAsTable(destination.identifier.unquotedString)
       }
     }
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
index 55a03a2d19f9..5df12be7f4cf 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
@@ -192,6 +192,7 @@ class SqlGraphRegistrationContext(
           specifiedSchema =
             Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
           partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+          clusterCols = None,
           properties = cst.tableSpec.properties,
           origin = queryOrigin.copy(
             objectName = Option(stIdentifier.unquotedString),
@@ -223,6 +224,7 @@ class SqlGraphRegistrationContext(
           specifiedSchema =
             Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
           partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+          clusterCols = None,
           properties = cst.tableSpec.properties,
           origin = queryOrigin.copy(
             objectName = Option(stIdentifier.unquotedString),
@@ -273,6 +275,7 @@ class SqlGraphRegistrationContext(
           specifiedSchema =
             Option.when(cmv.columns.nonEmpty)(StructType(cmv.columns.map(_.toV1Column))),
           partitionCols = Option(PartitionHelper.applyPartitioning(cmv.partitioning, queryOrigin)),
+          clusterCols = None,
           properties = cmv.tableSpec.properties,
           origin = queryOrigin.copy(
             objectName = Option(mvIdentifier.unquotedString),
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
index 87e01ed2021e..c762174e6725 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -114,6 +114,7 @@ sealed trait TableInput extends Input {
  * @param identifier The identifier of this table within the graph.
  * @param specifiedSchema The user-specified schema for this table.
  * @param partitionCols What columns the table should be partitioned by when materialized.
+ * @param clusterCols What columns the table should be clustered by when materialized.
  * @param normalizedPath Normalized storage location for the table based on the user-specified table
  *                       path (if not defined, we will normalize a managed storage path for it).
  * @param properties Table Properties to set in table metadata.
@@ -124,6 +125,7 @@ case class Table(
     identifier: TableIdentifier,
     specifiedSchema: Option[StructType],
     partitionCols: Option[Seq[String]],
+    clusterCols: Option[Seq[String]],
     normalizedPath: Option[String],
     properties: Map[String, String] = Map.empty,
     comment: Option[String],
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 31afc5a27a54..ba8419eb6e9c 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.pipelines.graph
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog}
-import org.apache.spark.sql.connector.expressions.Expressions
+import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions, FieldReference}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
 import org.apache.spark.sql.pipelines.graph.DatasetManager.TableMaterializationException
 import org.apache.spark.sql.pipelines.utils.{BaseCoreExecutionTest, TestGraphRegistrationContext}
@@ -885,4 +886,247 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest {
       storageRoot = storageRoot
     )
   }
+
+  test("cluster columns with user schema") {
+    val session = spark
+    import session.implicits._
+
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerTable(
+          "a",
+          query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+          specifiedSchema = Option(
+            new StructType()
+              .add("x1", IntegerType)
+              .add("x2", IntegerType)
+              .add("x3", StringType)
+          ),
+          clusterCols = Option(Seq("x1", "x3"))
+        )
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
+    val table = catalog.loadTable(identifier)
+    assert(
+      table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        new StructType()
+          .add("x1", IntegerType)
+          .add("x2", IntegerType)
+          .add("x3", StringType)
+      )
+    )
+    val expectedClusterTransform = ClusterByTransform(
+      Seq(FieldReference("x1"), FieldReference("x3")).toSeq
+    )
+    assert(table.partitioning().contains(expectedClusterTransform))
+  }
+
+  test("specifying cluster column with existing clustered table") {
+    val session = spark
+    import session.implicits._
+
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerTable(
+          "t10",
+          query = Option(dfFlowFunc(Seq((1, true, "a"), (2, false, "b")).toDF("x", "y", "z"))),
+          clusterCols = Option(Seq("x", "z"))
+        )
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t10")
+    val table = catalog.loadTable(identifier)
+    val expectedClusterTransform = ClusterByTransform(
+      Seq(FieldReference("x"), FieldReference("z")).toSeq
+    )
+    assert(table.partitioning().contains(expectedClusterTransform))
+
+    // Specify the same cluster columns - should work
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerFlow(
+          "t10",
+          "t10",
+          query = dfFlowFunc(Seq((3, true, "c"), (4, false, "d")).toDF("x", "y", "z"))
+        )
+        registerTable("t10", clusterCols = Option(Seq("x", "z")))
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+
+    val table2 = catalog.loadTable(identifier)
+    assert(table2.partitioning().contains(expectedClusterTransform))
+
+    // Don't specify cluster columns when table already has them - should throw
+    val ex = intercept[TableMaterializationException] {
+      materializeGraph(
+        new TestGraphRegistrationContext(spark) {
+          registerFlow(
+            "t10",
+            "t10",
+            query = dfFlowFunc(Seq((5, true, "e"), (6, false, "f")).toDF("x", "y", "z"))
+          )
+          registerTable("t10")
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
+      )
+    }
+    assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS")
+  }
+
+  test("specifying cluster column different from existing clustered table") {
+    val session = spark
+    import session.implicits._
+
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerTable(
+          "t11",
+          query = Option(dfFlowFunc(Seq((1, true, "a"), (2, false, "b")).toDF("x", "y", "z"))),
+          clusterCols = Option(Seq("x"))
+        )
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t11")
+
+    // Specify different cluster columns - should throw
+    val ex = intercept[TableMaterializationException] {
+      materializeGraph(
+        new TestGraphRegistrationContext(spark) {
+          registerFlow(
+            "t11",
+            "t11",
+            query = dfFlowFunc(Seq((3, true, "c"), (4, false, "d")).toDF("x", "y", "z"))
+          )
+          registerTable("t11", clusterCols = Option(Seq("y")))
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
+      )
+    }
+    assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS")
+
+    val table = catalog.loadTable(identifier)
+    val expectedClusterTransform = ClusterByTransform(Seq(FieldReference("x")).toSeq)
+    assert(table.partitioning().contains(expectedClusterTransform))
+  }
+
+  test("cluster columns only (no partitioning)") {
+    val session = spark
+    import session.implicits._
+
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerTable(
+          "t12",
+          query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+          specifiedSchema = Option(
+            new StructType()
+              .add("x1", IntegerType)
+              .add("x2", IntegerType)
+              .add("x3", StringType)
+          ),
+          clusterCols = Option(Seq("x1", "x3"))
+        )
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t12")
+    val table = catalog.loadTable(identifier)
+    assert(
+      table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        new StructType()
+          .add("x1", IntegerType)
+          .add("x2", IntegerType)
+          .add("x3", StringType)
+      )
+    )
+
+    val transforms = table.partitioning()
+    val expectedClusterTransform = ClusterByTransform(
+      Seq(FieldReference("x1"), FieldReference("x3")).toSeq
+    )
+    assert(transforms.contains(expectedClusterTransform))
+  }
+
+  test("materialized view with cluster columns") {
+    val session = spark
+    import session.implicits._
+
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerMaterializedView(
+          "mv1",
+          query = dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3")),
+          clusterCols = Option(Seq("x1", "x2"))
+        )
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
+    )
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "mv1")
+    val table = catalog.loadTable(identifier)
+    assert(
+      table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        new StructType()
+          .add("x1", IntegerType)
+          .add("x2", IntegerType)
+          .add("x3", StringType)
+      )
+    )
+    val expectedClusterTransform = ClusterByTransform(
+      Seq(FieldReference("x1"), FieldReference("x2")).toSeq
+    )
+    assert(table.partitioning().contains(expectedClusterTransform))
+  }
+
+  test("partition and cluster columns together should fail") {
+    val session = spark
+    import session.implicits._
+
+    val ex = intercept[TableMaterializationException] {
+      materializeGraph(
+        new TestGraphRegistrationContext(spark) {
+          registerTable(
+            "invalid_table",
+            query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+            partitionCols = Option(Seq("x2")),
+            clusterCols = Option(Seq("x1", "x3"))
+          )
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
+      )
+    }
+    assert(ex.cause.isInstanceOf[AnalysisException])
+    val analysisEx = ex.cause.asInstanceOf[AnalysisException]
+    assert(analysisEx.errorClass.get == "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED")
+  }
+
+  test("cluster column that doesn't exist in table schema should fail") {
+    val session = spark
+    import session.implicits._
+
+    val ex = intercept[TableMaterializationException] {
+      materializeGraph(
+        new TestGraphRegistrationContext(spark) {
+          registerTable(
+            "invalid_cluster_table",
+            query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+            clusterCols = Option(Seq("nonexistent_column"))
+          )
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
+      )
+    }
+    assert(ex.cause.isInstanceOf[AnalysisException])
+  }
 }
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
index bb7c8e833f84..c6b457ee04eb 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
@@ -542,6 +542,66 @@ trait APITest
     }
   }
 
+  test("Python Pipeline with cluster columns") {
+    val pipelineSpec =
+      TestPipelineSpec(include = Seq("transformations/**"))
+    val pipelineConfig = TestPipelineConfiguration(pipelineSpec)
+    val sources = Seq(
+      PipelineSourceFile(
+        name = "transformations/definition.py",
+        contents = """
+                     |from pyspark import pipelines as dp
+                     |from pyspark.sql import DataFrame, SparkSession
+                     |from pyspark.sql.functions import col
+                     |
+                     |spark = SparkSession.active()
+                     |
+                     |@dp.materialized_view(cluster_by = ["cluster_col1"])
+                     |def mv():
+                     |  df = spark.range(10)
+                     |  df = df.withColumn("cluster_col1", col("id") % 3)
+                     |  df = df.withColumn("cluster_col2", col("id") % 2)
+                     |  return df
+                     |
+                     |@dp.table(cluster_by = ["cluster_col1"])
+                     |def st():
+                     |  return spark.readStream.table("mv")
+                     |""".stripMargin))
+    val pipeline = createAndRunPipeline(pipelineConfig, sources)
+    awaitPipelineTermination(pipeline)
+
+    // Verify tables have correct data
+    Seq("mv", "st").foreach { tbl =>
+      val fullName = s"$tbl"
+      checkAnswer(
+        spark.sql(s"SELECT * FROM $fullName ORDER BY id"),
+        Seq(
+          Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 0), Row(3, 0, 1), Row(4, 1, 0),
+          Row(5, 2, 1), Row(6, 0, 0), Row(7, 1, 1), Row(8, 2, 0), Row(9, 0, 1)
+        ))
+    }
+
+    // Verify clustering information is stored in catalog
+    val catalog = spark.sessionState.catalogManager.currentCatalog
+      .asInstanceOf[org.apache.spark.sql.connector.catalog.TableCatalog]
+    // Check materialized view has clustering transform
+    val mvIdentifier = org.apache.spark.sql.connector.catalog.Identifier
+      .of(Array("default"), "mv")
+    val mvTable = catalog.loadTable(mvIdentifier)
+    val mvTransforms = mvTable.partitioning()
+    assert(mvTransforms.length == 1)
+    assert(mvTransforms.head.name() == "cluster_by")
+    assert(mvTransforms.head.toString.contains("cluster_col1"))
+    // Check streaming table has clustering transform
+    val stIdentifier = org.apache.spark.sql.connector.catalog.Identifier
+      .of(Array("default"), "st")
+    val stTable = catalog.loadTable(stIdentifier)
+    val stTransforms = stTable.partitioning()
+    assert(stTransforms.length == 1)
+    assert(stTransforms.head.name() == "cluster_by")
+    assert(stTransforms.head.toString.contains("cluster_col1"))
+  }
+
   /* Below tests pipeline execution configurations */
 
   test("Pipeline with dry run") {
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index 599aab87d1f7..d88432d68ca3 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -46,6 +46,7 @@ class TestGraphRegistrationContext(
       comment: Option[String] = None,
       specifiedSchema: Option[StructType] = None,
       partitionCols: Option[Seq[String]] = None,
+      clusterCols: Option[Seq[String]] = None,
       properties: Map[String, String] = Map.empty,
       baseOrigin: QueryOrigin = QueryOrigin.empty,
       format: Option[String] = None,
@@ -58,6 +59,7 @@ class TestGraphRegistrationContext(
     comment,
     specifiedSchema,
     partitionCols,
+    clusterCols,
     properties,
     baseOrigin,
     format,
@@ -99,6 +101,7 @@ class TestGraphRegistrationContext(
       comment: Option[String] = None,
       specifiedSchema: Option[StructType] = None,
       partitionCols: Option[Seq[String]] = None,
+      clusterCols: Option[Seq[String]] = None,
       properties: Map[String, String] = Map.empty,
       baseOrigin: QueryOrigin = QueryOrigin.empty,
       format: Option[String] = None,
@@ -111,6 +114,7 @@ class TestGraphRegistrationContext(
     comment,
     specifiedSchema,
     partitionCols,
+    clusterCols,
     properties,
     baseOrigin,
     format,
@@ -129,6 +133,7 @@ class TestGraphRegistrationContext(
       comment: Option[String],
       specifiedSchema: Option[StructType],
       partitionCols: Option[Seq[String]],
+      clusterCols: Option[Seq[String]],
       properties: Map[String, String],
       baseOrigin: QueryOrigin,
       format: Option[String],
@@ -150,6 +155,7 @@ class TestGraphRegistrationContext(
         comment = comment,
         specifiedSchema = specifiedSchema,
         partitionCols = partitionCols,
+        clusterCols = clusterCols,
         properties = properties,
         origin = baseOrigin.merge(
           QueryOrigin(

