This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a927a14a59e2 [SPARK-52463][SDP] Add support for cluster_by in Python Pipelines APIs
a927a14a59e2 is described below
commit a927a14a59e2d07ae10e73f3e28f3ca1b1208929
Author: Sandy Ryza <[email protected]>
AuthorDate: Thu Nov 6 10:45:08 2025 -0800
[SPARK-52463][SDP] Add support for cluster_by in Python Pipelines APIs
### What changes were proposed in this pull request?
In the `table` and `materialized_view` decorators, accept a `cluster_by`
argument that determines the clustering columns.
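For illustration, a minimal sketch of how the new argument might be used in a
pipeline definition file (the table name, source table, and column below are
hypothetical, not taken from this patch):

    from pyspark import pipelines as dp
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.active()

    # Hypothetical usage: ask the pipeline to cluster the materialized view
    # by the "region" column when it materializes the table.
    @dp.materialized_view(cluster_by=["region"])
    def sales_by_region():
        return spark.read.table("raw_sales").withColumn("region", col("region_id") % 4)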
### Why are the changes needed?
Parity with the `clusterBy` API supported by `DataFrameWriter` and
`DataStreamWriter`.
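For comparison, the existing writer-side calls look roughly like this (a sketch;
`df` and the table names are placeholders):

    # Batch write, clustering the output table by "region".
    df.write.clusterBy("region").mode("append").saveAsTable("sales_clustered")

    # Streaming write with the same clustering columns.
    df.writeStream.clusterBy("region").toTable("sales_clustered_stream")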
### Does this PR introduce _any_ user-facing change?
Adds a new parameter to public APIs.
### How was this patch tested?
Unit tests and integration tests.
### Was this patch authored or co-authored using generative AI tooling?
Closes #52831 from sryza/cluster-by.
Authored-by: Sandy Ryza <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/pipelines/api.py | 18 +-
python/pyspark/pipelines/output.py | 2 +
.../spark_connect_graph_element_registry.py | 1 +
python/pyspark/sql/connect/proto/pipelines_pb2.py | 88 ++++----
python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 9 +
.../main/protobuf/spark/connect/pipelines.proto | 3 +
.../sql/connect/pipelines/PipelinesHandler.scala | 2 +
.../connect/pipelines/PythonPipelineSuite.scala | 37 ++++
.../connect/pipelines/TestPipelineDefinition.scala | 2 +
.../spark/sql/pipelines/graph/DatasetManager.scala | 27 ++-
.../spark/sql/pipelines/graph/FlowExecution.scala | 16 +-
.../graph/SqlGraphRegistrationContext.scala | 3 +
.../spark/sql/pipelines/graph/elements.scala | 2 +
.../pipelines/graph/MaterializeTablesSuite.scala | 246 ++++++++++++++++++++-
.../apache/spark/sql/pipelines/utils/APITest.scala | 60 +++++
.../utils/TestGraphRegistrationContext.scala | 6 +
16 files changed, 464 insertions(+), 58 deletions(-)
diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py
index b68cc30b43a7..c01e8524eee2 100644
--- a/python/pyspark/pipelines/api.py
+++ b/python/pyspark/pipelines/api.py
@@ -76,6 +76,7 @@ def _validate_stored_dataset_args(
name: Optional[str],
table_properties: Optional[Dict[str, str]],
partition_cols: Optional[List[str]],
+ cluster_by: Optional[List[str]],
) -> None:
if name is not None and type(name) is not str:
raise PySparkTypeError(
@@ -91,6 +92,7 @@ def _validate_stored_dataset_args(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
+ validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)
@overload
@@ -107,6 +109,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -120,6 +123,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -142,11 +146,12 @@ def table(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
- _validate_stored_dataset_args(name, table_properties, partition_cols)
+ _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -163,6 +168,7 @@ def table(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -209,6 +215,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -222,6 +229,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -244,11 +252,12 @@ def materialized_view(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
- _validate_stored_dataset_args(name, table_properties, partition_cols)
+ _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -265,6 +274,7 @@ def materialized_view(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -403,6 +413,7 @@ def create_streaming_table(
comment: Optional[str] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> None:
@@ -417,6 +428,7 @@ def create_streaming_table(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
@@ -435,6 +447,7 @@ def create_streaming_table(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
+ validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -444,6 +457,7 @@ def create_streaming_table(
source_code_location=source_code_location,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
format=format,
)
diff --git a/python/pyspark/pipelines/output.py b/python/pyspark/pipelines/output.py
index 84e950f16174..92058e68721f 100644
--- a/python/pyspark/pipelines/output.py
+++ b/python/pyspark/pipelines/output.py
@@ -45,6 +45,7 @@ class Table(Output):
:param table_properties: A dict where the keys are the property names and the values are the
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema Explicit Spark SQL schema to materialize this table with. Supports either a
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
@@ -52,6 +53,7 @@ class Table(Output):
table_properties: Mapping[str, str]
partition_cols: Optional[Sequence[str]]
+ cluster_by: Optional[Sequence[str]]
schema: Optional[Union[StructType, str]]
format: Optional[str]
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index 5c5ef9fc3040..e8a8561c3e74 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -63,6 +63,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
table_properties=output.table_properties,
partition_cols=output.partition_cols,
+ clustering_columns=output.cluster_by,
format=output.format,
# Even though schema_string is not required, the generated Python code seems to
# erroneously think it is required.
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index d7321fa7cf0c..139de83dc1aa 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -42,7 +42,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9c"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02
\x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutp [...]
+
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcb"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02
\x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutp [...]
)
_globals = globals()
@@ -69,10 +69,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
]._serialized_options = b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options =
b"8\001"
- _globals["_OUTPUTTYPE"]._serialized_start = 6058
- _globals["_OUTPUTTYPE"]._serialized_end = 6163
+ _globals["_OUTPUTTYPE"]._serialized_start = 6105
+ _globals["_OUTPUTTYPE"]._serialized_end = 6210
_globals["_PIPELINECOMMAND"]._serialized_start = 195
- _globals["_PIPELINECOMMAND"]._serialized_end = 4575
+ _globals["_PIPELINECOMMAND"]._serialized_end = 4622
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1129
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1437
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 1338
@@ -80,51 +80,51 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1439
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1529
_globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_start = 1532
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2783
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2830
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_start =
2068
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end =
2469
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end =
2516
_globals[
"_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
- ]._serialized_start = 2382
+ ]._serialized_start = 2429
_globals[
"_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
- ]._serialized_end = 2448
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start =
2472
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end =
2681
-
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start
= 2612
-
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end
= 2670
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2786
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3647
+ ]._serialized_end = 2495
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start =
2519
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end =
2728
+
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start
= 2659
+
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end
= 2717
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2833
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3694
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start =
1338
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1396
-
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start
= 3380
-
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end
= 3477
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3479
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3537
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3650
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3972
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
3975
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4174
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 4177
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 4335
-
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
4338
- _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 4559
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4578
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5330
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 4947
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
5045
- _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start =
5048
- _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end =
5181
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
5184
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5315
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5332
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5405
- _globals["_PIPELINEEVENT"]._serialized_start = 5407
- _globals["_PIPELINEEVENT"]._serialized_end = 5523
- _globals["_SOURCECODELOCATION"]._serialized_start = 5526
- _globals["_SOURCECODELOCATION"]._serialized_end = 5767
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5769
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5838
- _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5841
- _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6056
+
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start
= 3427
+
_globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end
= 3524
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3526
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3584
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3697
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 4019
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
4022
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4221
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 4224
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 4382
+
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
4385
+ _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 4606
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4625
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5377
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 4994
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
5092
+ _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start =
5095
+ _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end =
5228
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
5231
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5362
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5379
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5452
+ _globals["_PIPELINEEVENT"]._serialized_start = 5454
+ _globals["_PIPELINEEVENT"]._serialized_end = 5570
+ _globals["_SOURCECODELOCATION"]._serialized_start = 5573
+ _globals["_SOURCECODELOCATION"]._serialized_end = 5814
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5816
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5885
+ _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5888
+ _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6103
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index b9170e763ed9..60d131037c99 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -240,6 +240,7 @@ class PipelineCommand(google.protobuf.message.Message):
FORMAT_FIELD_NUMBER: builtins.int
SCHEMA_DATA_TYPE_FIELD_NUMBER: builtins.int
SCHEMA_STRING_FIELD_NUMBER: builtins.int
+ CLUSTERING_COLUMNS_FIELD_NUMBER: builtins.int
@property
def table_properties(
self,
@@ -255,6 +256,11 @@ class PipelineCommand(google.protobuf.message.Message):
@property
def schema_data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
schema_string: builtins.str
+ @property
+ def clustering_columns(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """Optional cluster columns for the table."""
def __init__(
self,
*,
@@ -263,6 +269,7 @@ class PipelineCommand(google.protobuf.message.Message):
format: builtins.str | None = ...,
schema_data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
schema_string: builtins.str = ...,
+ clustering_columns: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
@@ -284,6 +291,8 @@ class PipelineCommand(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"_format",
b"_format",
+ "clustering_columns",
+ b"clustering_columns",
"format",
b"format",
"partition_cols",
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index c6a5e571f979..0fa36f8a1514 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -104,6 +104,9 @@ message PipelineCommand {
spark.connect.DataType schema_data_type = 4;
string schema_string = 5;
}
+
+ // Optional cluster columns for the table.
+ repeated string clustering_columns = 6;
}
// Metadata that's only applicable to sinks.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7e69e546893e..0929b07be523 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -203,6 +203,8 @@ private[connect] object PipelinesHandler extends Logging {
},
partitionCols = Option(tableDetails.getPartitionColsList.asScala.toSeq)
.filter(_.nonEmpty),
+ clusterCols = Option(tableDetails.getClusteringColumnsList.asScala.toSeq)
+ .filter(_.nonEmpty),
properties = tableDetails.getTablePropertiesMap.asScala.toMap,
origin = QueryOrigin(
filePath = Option.when(output.getSourceCodeLocation.hasFileName)(
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 79c34ac46b9f..1a72d112aa2e 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -865,4 +865,41 @@ class PythonPipelineSuite
(exitCode, output.toSeq)
}
+
+ test("empty cluster_by list should work and create table with no
clustering") {
+ withTable("mv", "st") {
+ val graph = buildGraph("""
+ |from pyspark.sql.functions import col
+ |
+ |@dp.materialized_view(cluster_by = [])
+ |def mv():
+ | return spark.range(5).withColumn("id_mod", col("id") % 2)
+ |
+ |@dp.table(cluster_by = [])
+ |def st():
+ | return spark.readStream.table("mv")
+ |""".stripMargin)
+ val updateContext =
+ new PipelineUpdateContextImpl(graph, eventCallback = _ => (), storageRoot = storageRoot)
+ updateContext.pipelineExecution.runPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+
+ // Check tables are created with no clustering transforms
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+
+ val mvIdentifier = Identifier.of(Array("default"), "mv")
+ val mvTable = catalog.loadTable(mvIdentifier)
+ val mvTransforms = mvTable.partitioning()
+ assert(
+ mvTransforms.isEmpty,
+ s"MaterializedView should have no transforms, but got:
${mvTransforms.mkString(", ")}")
+
+ val stIdentifier = Identifier.of(Array("default"), "st")
+ val stTable = catalog.loadTable(stIdentifier)
+ val stTransforms = stTable.partitioning()
+ assert(
+ stTransforms.isEmpty,
+ s"Table should have no transforms, but got: ${stTransforms.mkString(",
")}")
+ }
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
index dfb766b1df77..f3b63f791421 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
@@ -41,10 +41,12 @@ class TestPipelineDefinition(graphId: String) {
// TODO: Add support for specifiedSchema
// specifiedSchema: Option[StructType] = None,
partitionCols: Option[Seq[String]] = None,
+ clusterCols: Option[Seq[String]] = None,
properties: Map[String, String] = Map.empty): Unit = {
val tableDetails = sc.PipelineCommand.DefineOutput.TableDetails
.newBuilder()
.addAllPartitionCols(partitionCols.getOrElse(Seq()).asJava)
+ .addAllClusteringColumns(clusterCols.getOrElse(Seq()).asJava)
.putAllTableProperties(properties.asJava)
.build()
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
index cb142988ce51..e5c87fa542ad 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.{
TableInfo
}
import org.apache.spark.sql.connector.catalog.CatalogV2Util.v2ColumnsToStructType
-import org.apache.spark.sql.connector.expressions.Expressions
+import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions}
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.pipelines.graph.QueryOrigin.ExceptionHelpers
import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils.diffSchemas
@@ -266,6 +266,19 @@ object DatasetManager extends Logging {
)
val mergedProperties = resolveTableProperties(table, identifier)
val partitioning = table.partitionCols.toSeq.flatten.map(Expressions.identity)
+ val clustering = table.clusterCols.map(cols =>
+ ClusterByTransform(cols.map(col => Expressions.column(col)))
+ ).toSeq
+
+ // Validate that partition and cluster columns don't coexist
+ if (partitioning.nonEmpty && clustering.nonEmpty) {
+ throw new AnalysisException(
+ errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED",
+ messageParameters = Map.empty
+ )
+ }
+
+ val allTransforms = partitioning ++ clustering
val existingTableOpt = if (catalog.tableExists(identifier)) {
Some(catalog.loadTable(identifier))
@@ -273,15 +286,15 @@ object DatasetManager extends Logging {
None
}
- // Error if partitioning doesn't match
+ // Error if partitioning/clustering doesn't match
if (existingTableOpt.isDefined) {
- val existingPartitioning = existingTableOpt.get.partitioning().toSeq
- if (existingPartitioning != partitioning) {
+ val existingTransforms = existingTableOpt.get.partitioning().toSeq
+ if (existingTransforms != allTransforms) {
throw new AnalysisException(
errorClass = "CANNOT_UPDATE_PARTITION_COLUMNS",
messageParameters = Map(
- "existingPartitionColumns" -> existingPartitioning.mkString(", "),
- "requestedPartitionColumns" -> partitioning.mkString(", ")
+ "existingPartitionColumns" -> existingTransforms.mkString(", "),
+ "requestedPartitionColumns" -> allTransforms.mkString(", ")
)
)
}
@@ -314,7 +327,7 @@ object DatasetManager extends Logging {
new TableInfo.Builder()
.withProperties(mergedProperties.asJava)
.withColumns(CatalogV2Util.structTypeToV2Columns(outputSchema))
- .withPartitions(partitioning.toArray)
+ .withPartitions(allTransforms.toArray)
.build()
)
}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
index 2c9029fdd34d..647df79bb940 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
@@ -264,12 +264,20 @@ class BatchTableWrite(
if (destination.format.isDefined) {
dataFrameWriter.format(destination.format.get)
}
+
+ // In "append" mode with saveAsTable, partition/cluster columns must
be specified in query
+ // because the format and options of the existing table is used, and
the table could
+ // have been created with partition columns.
+ if (destination.clusterCols.isDefined) {
+ val clusterCols = destination.clusterCols.get
+ dataFrameWriter.clusterBy(clusterCols.head, clusterCols.tail: _*)
+ }
+ if (destination.partitionCols.isDefined) {
+ dataFrameWriter.partitionBy(destination.partitionCols.get: _*)
+ }
+
dataFrameWriter
.mode("append")
- // In "append" mode with saveAsTable, partition columns must be
specified in query
- // because the format and options of the existing table is used, and
the table could
- // have been created with partition columns.
- .partitionBy(destination.partitionCols.getOrElse(Seq.empty): _*)
.saveAsTable(destination.identifier.unquotedString)
}
}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
index 55a03a2d19f9..5df12be7f4cf 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
@@ -192,6 +192,7 @@ class SqlGraphRegistrationContext(
specifiedSchema =
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols =
Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+ clusterCols = None,
properties = cst.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
@@ -223,6 +224,7 @@ class SqlGraphRegistrationContext(
specifiedSchema =
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols =
Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+ clusterCols = None,
properties = cst.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
@@ -273,6 +275,7 @@ class SqlGraphRegistrationContext(
specifiedSchema =
Option.when(cmv.columns.nonEmpty)(StructType(cmv.columns.map(_.toV1Column))),
partitionCols =
Option(PartitionHelper.applyPartitioning(cmv.partitioning, queryOrigin)),
+ clusterCols = None,
properties = cmv.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(mvIdentifier.unquotedString),
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
index 87e01ed2021e..c762174e6725 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -114,6 +114,7 @@ sealed trait TableInput extends Input {
* @param identifier The identifier of this table within the graph.
* @param specifiedSchema The user-specified schema for this table.
* @param partitionCols What columns the table should be partitioned by when materialized.
+ * @param clusterCols What columns the table should be clustered by when materialized.
* @param normalizedPath Normalized storage location for the table based on the user-specified table
* path (if not defined, we will normalize a managed storage path for it).
* @param properties Table Properties to set in table metadata.
@@ -124,6 +125,7 @@ case class Table(
identifier: TableIdentifier,
specifiedSchema: Option[StructType],
partitionCols: Option[Seq[String]],
+ clusterCols: Option[Seq[String]],
normalizedPath: Option[String],
properties: Map[String, String] = Map.empty,
comment: Option[String],
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 31afc5a27a54..ba8419eb6e9c 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.pipelines.graph
import scala.jdk.CollectionConverters._
import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog}
-import org.apache.spark.sql.connector.expressions.Expressions
+import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions, FieldReference}
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
import org.apache.spark.sql.pipelines.graph.DatasetManager.TableMaterializationException
import org.apache.spark.sql.pipelines.utils.{BaseCoreExecutionTest, TestGraphRegistrationContext}
@@ -885,4 +886,247 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest {
storageRoot = storageRoot
)
}
+
+ test("cluster columns with user schema") {
+ val session = spark
+ import session.implicits._
+
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "a",
+ query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+ specifiedSchema = Option(
+ new StructType()
+ .add("x1", IntegerType)
+ .add("x2", IntegerType)
+ .add("x3", StringType)
+ ),
+ clusterCols = Option(Seq("x1", "x3"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
+ val table = catalog.loadTable(identifier)
+ assert(
+ table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+ new StructType()
+ .add("x1", IntegerType)
+ .add("x2", IntegerType)
+ .add("x3", StringType)
+ )
+ )
+ val expectedClusterTransform = ClusterByTransform(
+ Seq(FieldReference("x1"), FieldReference("x3")).toSeq
+ )
+ assert(table.partitioning().contains(expectedClusterTransform))
+ }
+
+ test("specifying cluster column with existing clustered table") {
+ val session = spark
+ import session.implicits._
+
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "t10",
+ query = Option(dfFlowFunc(Seq((1, true, "a"), (2, false, "b")).toDF("x", "y", "z"))),
+ clusterCols = Option(Seq("x", "z"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t10")
+ val table = catalog.loadTable(identifier)
+ val expectedClusterTransform = ClusterByTransform(
+ Seq(FieldReference("x"), FieldReference("z")).toSeq
+ )
+ assert(table.partitioning().contains(expectedClusterTransform))
+
+ // Specify the same cluster columns - should work
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerFlow(
+ "t10",
+ "t10",
+ query = dfFlowFunc(Seq((3, true, "c"), (4, false, "d")).toDF("x", "y", "z"))
+ )
+ registerTable("t10", clusterCols = Option(Seq("x", "z")))
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+
+ val table2 = catalog.loadTable(identifier)
+ assert(table2.partitioning().contains(expectedClusterTransform))
+
+ // Don't specify cluster columns when table already has them - should throw
+ val ex = intercept[TableMaterializationException] {
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerFlow(
+ "t10",
+ "t10",
+ query = dfFlowFunc(Seq((5, true, "e"), (6, false, "f")).toDF("x", "y", "z"))
+ )
+ registerTable("t10")
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ }
+ assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS")
+ }
+
+ test("specifying cluster column different from existing clustered table") {
+ val session = spark
+ import session.implicits._
+
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "t11",
+ query = Option(dfFlowFunc(Seq((1, true, "a"), (2, false, "b")).toDF("x", "y", "z"))),
+ clusterCols = Option(Seq("x"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t11")
+
+ // Specify different cluster columns - should throw
+ val ex = intercept[TableMaterializationException] {
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerFlow(
+ "t11",
+ "t11",
+ query = dfFlowFunc(Seq((3, true, "c"), (4, false, "d")).toDF("x", "y", "z"))
+ )
+ registerTable("t11", clusterCols = Option(Seq("y")))
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ }
+ assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS")
+
+ val table = catalog.loadTable(identifier)
+ val expectedClusterTransform = ClusterByTransform(Seq(FieldReference("x")).toSeq)
+ assert(table.partitioning().contains(expectedClusterTransform))
+ }
+
+ test("cluster columns only (no partitioning)") {
+ val session = spark
+ import session.implicits._
+
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "t12",
+ query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+ specifiedSchema = Option(
+ new StructType()
+ .add("x1", IntegerType)
+ .add("x2", IntegerType)
+ .add("x3", StringType)
+ ),
+ clusterCols = Option(Seq("x1", "x3"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t12")
+ val table = catalog.loadTable(identifier)
+ assert(
+ table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+ new StructType()
+ .add("x1", IntegerType)
+ .add("x2", IntegerType)
+ .add("x3", StringType)
+ )
+ )
+
+ val transforms = table.partitioning()
+ val expectedClusterTransform = ClusterByTransform(
+ Seq(FieldReference("x1"), FieldReference("x3")).toSeq
+ )
+ assert(transforms.contains(expectedClusterTransform))
+ }
+
+ test("materialized view with cluster columns") {
+ val session = spark
+ import session.implicits._
+
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerMaterializedView(
+ "mv1",
+ query = dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3")),
+ clusterCols = Option(Seq("x1", "x2"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "mv1")
+ val table = catalog.loadTable(identifier)
+ assert(
+ table.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+ new StructType()
+ .add("x1", IntegerType)
+ .add("x2", IntegerType)
+ .add("x3", StringType)
+ )
+ )
+ val expectedClusterTransform = ClusterByTransform(
+ Seq(FieldReference("x1"), FieldReference("x2")).toSeq
+ )
+ assert(table.partitioning().contains(expectedClusterTransform))
+ }
+
+ test("partition and cluster columns together should fail") {
+ val session = spark
+ import session.implicits._
+
+ val ex = intercept[TableMaterializationException] {
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "invalid_table",
+ query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+ partitionCols = Option(Seq("x2")),
+ clusterCols = Option(Seq("x1", "x3"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ }
+ assert(ex.cause.isInstanceOf[AnalysisException])
+ val analysisEx = ex.cause.asInstanceOf[AnalysisException]
+ assert(analysisEx.errorClass.get == "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED")
+ }
+
+ test("cluster column that doesn't exist in table schema should fail") {
+ val session = spark
+ import session.implicits._
+
+ val ex = intercept[TableMaterializationException] {
+ materializeGraph(
+ new TestGraphRegistrationContext(spark) {
+ registerTable(
+ "invalid_cluster_table",
+ query = Option(dfFlowFunc(Seq((1, 1, "x"), (2, 3, "y")).toDF("x1", "x2", "x3"))),
+ clusterCols = Option(Seq("nonexistent_column"))
+ )
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
+ )
+ }
+ assert(ex.cause.isInstanceOf[AnalysisException])
+ }
}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
index bb7c8e833f84..c6b457ee04eb 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
@@ -542,6 +542,66 @@ trait APITest
}
}
+ test("Python Pipeline with cluster columns") {
+ val pipelineSpec =
+ TestPipelineSpec(include = Seq("transformations/**"))
+ val pipelineConfig = TestPipelineConfiguration(pipelineSpec)
+ val sources = Seq(
+ PipelineSourceFile(
+ name = "transformations/definition.py",
+ contents = """
+ |from pyspark import pipelines as dp
+ |from pyspark.sql import DataFrame, SparkSession
+ |from pyspark.sql.functions import col
+ |
+ |spark = SparkSession.active()
+ |
+ |@dp.materialized_view(cluster_by = ["cluster_col1"])
+ |def mv():
+ | df = spark.range(10)
+ | df = df.withColumn("cluster_col1", col("id") % 3)
+ | df = df.withColumn("cluster_col2", col("id") % 2)
+ | return df
+ |
+ |@dp.table(cluster_by = ["cluster_col1"])
+ |def st():
+ | return spark.readStream.table("mv")
+ |""".stripMargin))
+ val pipeline = createAndRunPipeline(pipelineConfig, sources)
+ awaitPipelineTermination(pipeline)
+
+ // Verify tables have correct data
+ Seq("mv", "st").foreach { tbl =>
+ val fullName = s"$tbl"
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $fullName ORDER BY id"),
+ Seq(
+ Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 0), Row(3, 0, 1), Row(4, 1, 0),
+ Row(5, 2, 1), Row(6, 0, 0), Row(7, 1, 1), Row(8, 2, 0), Row(9, 0, 1)
+ ))
+ }
+
+ // Verify clustering information is stored in catalog
+ val catalog = spark.sessionState.catalogManager.currentCatalog
+ .asInstanceOf[org.apache.spark.sql.connector.catalog.TableCatalog]
+ // Check materialized view has clustering transform
+ val mvIdentifier = org.apache.spark.sql.connector.catalog.Identifier
+ .of(Array("default"), "mv")
+ val mvTable = catalog.loadTable(mvIdentifier)
+ val mvTransforms = mvTable.partitioning()
+ assert(mvTransforms.length == 1)
+ assert(mvTransforms.head.name() == "cluster_by")
+ assert(mvTransforms.head.toString.contains("cluster_col1"))
+ // Check streaming table has clustering transform
+ val stIdentifier = org.apache.spark.sql.connector.catalog.Identifier
+ .of(Array("default"), "st")
+ val stTable = catalog.loadTable(stIdentifier)
+ val stTransforms = stTable.partitioning()
+ assert(stTransforms.length == 1)
+ assert(stTransforms.head.name() == "cluster_by")
+ assert(stTransforms.head.toString.contains("cluster_col1"))
+ }
+
/* Below tests pipeline execution configurations */
test("Pipeline with dry run") {
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index 599aab87d1f7..d88432d68ca3 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -46,6 +46,7 @@ class TestGraphRegistrationContext(
comment: Option[String] = None,
specifiedSchema: Option[StructType] = None,
partitionCols: Option[Seq[String]] = None,
+ clusterCols: Option[Seq[String]] = None,
properties: Map[String, String] = Map.empty,
baseOrigin: QueryOrigin = QueryOrigin.empty,
format: Option[String] = None,
@@ -58,6 +59,7 @@ class TestGraphRegistrationContext(
comment,
specifiedSchema,
partitionCols,
+ clusterCols,
properties,
baseOrigin,
format,
@@ -99,6 +101,7 @@ class TestGraphRegistrationContext(
comment: Option[String] = None,
specifiedSchema: Option[StructType] = None,
partitionCols: Option[Seq[String]] = None,
+ clusterCols: Option[Seq[String]] = None,
properties: Map[String, String] = Map.empty,
baseOrigin: QueryOrigin = QueryOrigin.empty,
format: Option[String] = None,
@@ -111,6 +114,7 @@ class TestGraphRegistrationContext(
comment,
specifiedSchema,
partitionCols,
+ clusterCols,
properties,
baseOrigin,
format,
@@ -129,6 +133,7 @@ class TestGraphRegistrationContext(
comment: Option[String],
specifiedSchema: Option[StructType],
partitionCols: Option[Seq[String]],
+ clusterCols: Option[Seq[String]],
properties: Map[String, String],
baseOrigin: QueryOrigin,
format: Option[String],
@@ -150,6 +155,7 @@ class TestGraphRegistrationContext(
comment = comment,
specifiedSchema = specifiedSchema,
partitionCols = partitionCols,
+ clusterCols = clusterCols,
properties = properties,
origin = baseOrigin.merge(
QueryOrigin(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]