This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e56ab2fdaea0 [SPARK-53593][SDP] Add response field for DefineDataset and DefineFlow RPC
e56ab2fdaea0 is described below
commit e56ab2fdaea0946e323367f8924a99ca948f53d0
Author: Jessie Luo <[email protected]>
AuthorDate: Fri Sep 26 12:48:17 2025 -0700
[SPARK-53593][SDP] Add response field for DefineDataset and DefineFlow RPC
### What changes were proposed in this pull request?
This PR updates the Spark Connect server to return resolved dataset and
flow names in the responses of DefineDataset and DefineFlow RPCs.
Changes include:
1. Adding a shared `ResolvedIdentifier` message to common.proto, and new
`DefineDatasetResult` and `DefineFlowResult` messages, each carrying a
`resolved_identifier` field, to `PipelineCommandResult`.
2. Updating the RPC handlers to return the resolved identifiers in their
responses.
3. Adding unit tests in SparkDeclarativePipelinesServerSuite to validate the
resolved names.
### Why are the changes needed?
The Spark Connect client requires the resolved names for datasets and flows to
support graph resolution in the declarative pipelines frontend. Returning this
information from the server ensures consistent naming and proper registration.
### Does this PR introduce _any_ user-facing change?
Yes. The DefineDataset and DefineFlow RPC responses now include fully
qualified names such as `catalog`.`db`.`mv`. Flows defined implicitly for
temporary views return unqualified names such as `mv`.
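For illustration, a client could assemble the displayed name from the new
result message roughly as follows (a minimal sketch; `fullyQualifiedName` and
its surrounding plumbing are hypothetical, while the getters come from the
generated bindings in this change):
```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.connect.proto.PipelineCommandResult

// Sketch: derive a quoted, fully qualified name from a DefineDataset result.
// Unqualified identifiers (e.g. temporary views) carry an empty catalog_name
// and namespace, so only the table name survives the filter.
def fullyQualifiedName(result: PipelineCommandResult): String = {
  val id = result.getDefineDatasetResult.getResolvedIdentifier
  (Seq(id.getCatalogName) ++ id.getNamespaceList.asScala :+ id.getTableName)
    .filter(_.nonEmpty)
    .map(part => s"`$part`")
    .mkString(".")
}
```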
### How was this patch tested?
Added targeted unit tests in SparkDeclarativePipelinesServerSuite. Verified
both default and custom catalog/database cases.
Run the tests with:
```
./build/sbt
> project connect
> testOnly *SparkDeclarativePipelinesServerSuite
```
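Equivalently, as a one-shot invocation (standard sbt syntax, assuming the same
project layout):
```
./build/sbt "connect/testOnly *SparkDeclarativePipelinesServerSuite"
```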
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52328 from cookiedough77/jessie.luo_data/spark-add-response.
Lead-authored-by: Jessie Luo <[email protected]>
Co-authored-by: cookiedough77
<[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
python/pyspark/sql/connect/proto/common_pb2.py | 28 +-
python/pyspark/sql/connect/proto/common_pb2.pyi | 28 ++
python/pyspark/sql/connect/proto/pipelines_pb2.py | 69 ++---
python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 130 ++++++---
.../apache/spark/sql/catalyst/identifiers.scala | 18 +-
.../src/main/protobuf/spark/connect/common.proto | 6 +
.../main/protobuf/spark/connect/pipelines.proto | 16 +-
.../sql/connect/pipelines/PipelinesHandler.scala | 110 ++++++--
.../SparkDeclarativePipelinesServerSuite.scala | 294 ++++++++++++++++++++-
.../pipelines/graph/GraphRegistrationContext.scala | 78 +-----
.../utils/TestGraphRegistrationContext.scala | 54 +++-
11 files changed, 635 insertions(+), 196 deletions(-)
diff --git a/python/pyspark/sql/connect/proto/common_pb2.py b/python/pyspark/sql/connect/proto/common_pb2.py
index 4eaed50598e1..b761a1f5ccf6 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.py
+++ b/python/pyspark/sql/connect/proto/common_pb2.py
@@ -35,7 +35,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
+ b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
)
_globals = globals()
@@ -74,16 +74,18 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_JVMORIGIN"]._serialized_end = 1660
_globals["_STACKTRACEELEMENT"]._serialized_start = 1663
_globals["_STACKTRACEELEMENT"]._serialized_end = 2025
- _globals["_BOOLS"]._serialized_start = 2027
- _globals["_BOOLS"]._serialized_end = 2058
- _globals["_INTS"]._serialized_start = 2060
- _globals["_INTS"]._serialized_end = 2090
- _globals["_LONGS"]._serialized_start = 2092
- _globals["_LONGS"]._serialized_end = 2123
- _globals["_FLOATS"]._serialized_start = 2125
- _globals["_FLOATS"]._serialized_end = 2157
- _globals["_DOUBLES"]._serialized_start = 2159
- _globals["_DOUBLES"]._serialized_end = 2192
- _globals["_STRINGS"]._serialized_start = 2194
- _globals["_STRINGS"]._serialized_end = 2227
+ _globals["_RESOLVEDIDENTIFIER"]._serialized_start = 2027
+ _globals["_RESOLVEDIDENTIFIER"]._serialized_end = 2143
+ _globals["_BOOLS"]._serialized_start = 2145
+ _globals["_BOOLS"]._serialized_end = 2176
+ _globals["_INTS"]._serialized_start = 2178
+ _globals["_INTS"]._serialized_end = 2208
+ _globals["_LONGS"]._serialized_start = 2210
+ _globals["_LONGS"]._serialized_end = 2241
+ _globals["_FLOATS"]._serialized_start = 2243
+ _globals["_FLOATS"]._serialized_end = 2275
+ _globals["_DOUBLES"]._serialized_start = 2277
+ _globals["_DOUBLES"]._serialized_end = 2310
+ _globals["_STRINGS"]._serialized_start = 2312
+ _globals["_STRINGS"]._serialized_end = 2345
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/common_pb2.pyi b/python/pyspark/sql/connect/proto/common_pb2.pyi
index 8111cfed10cd..95addf6589d1 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/common_pb2.pyi
@@ -599,6 +599,34 @@ class StackTraceElement(google.protobuf.message.Message):
global___StackTraceElement = StackTraceElement
+class ResolvedIdentifier(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ CATALOG_NAME_FIELD_NUMBER: builtins.int
+ NAMESPACE_FIELD_NUMBER: builtins.int
+ TABLE_NAME_FIELD_NUMBER: builtins.int
+ catalog_name: builtins.str
+ @property
+ def namespace(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+ table_name: builtins.str
+ def __init__(
+ self,
+ *,
+ catalog_name: builtins.str = ...,
+ namespace: collections.abc.Iterable[builtins.str] | None = ...,
+ table_name: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "catalog_name", b"catalog_name", "namespace", b"namespace",
"table_name", b"table_name"
+ ],
+ ) -> None: ...
+
+global___ResolvedIdentifier = ResolvedIdentifier
+
class Bools(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index 08b39a39e831..1f3155646d62 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -35,12 +35,13 @@ _sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
+from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common__pb2
from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x97\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_f [...]
+ b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xc4\x13\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
)
_globals = globals()
@@ -59,36 +60,38 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options
= b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options =
b"8\001"
- _globals["_DATASETTYPE"]._serialized_start = 3191
- _globals["_DATASETTYPE"]._serialized_end = 3288
- _globals["_PIPELINECOMMAND"]._serialized_start = 140
- _globals["_PIPELINECOMMAND"]._serialized_end = 2723
- _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719
- _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110
-
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 928
-
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_end =
986
-
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_RESPONSE"]._serialized_start =
988
- _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_RESPONSE"]._serialized_end
= 1069
- _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1112
- _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1202
- _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_start = 1205
- _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 1798
-
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start
= 1642
-
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end
= 1708
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 1801
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2223
- _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start =
928
- _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2226
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2505
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
2508
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2707
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2726
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2996
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 2883
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
2981
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 2998
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3071
- _globals["_PIPELINEEVENT"]._serialized_start = 3073
- _globals["_PIPELINEEVENT"]._serialized_end = 3189
+ _globals["_DATASETTYPE"]._serialized_start = 3622
+ _globals["_DATASETTYPE"]._serialized_end = 3719
+ _globals["_PIPELINECOMMAND"]._serialized_start = 168
+ _globals["_PIPELINECOMMAND"]._serialized_end = 2668
+ _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 747
+ _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1055
+
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 956
+
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_end =
1014
+ _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1057
+ _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1147
+ _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_start = 1150
+ _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 1743
+
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start
= 1587
+
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end
= 1653
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 1746
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2168
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start =
956
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1014
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2171
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2450
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
2453
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2652
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2671
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 3427
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 3043
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
3141
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start =
3144
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end =
3278
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
3281
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 3412
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3429
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3502
+ _globals["_PIPELINEEVENT"]._serialized_start = 3504
+ _globals["_PIPELINEEVENT"]._serialized_end = 3620
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index 6287aabafc6b..a174b4d28293 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import google.protobuf.timestamp_pb2
+import pyspark.sql.connect.proto.common_pb2
import pyspark.sql.connect.proto.relations_pb2
import pyspark.sql.connect.proto.types_pb2
import sys
@@ -110,40 +111,6 @@ class PipelineCommand(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...
- class Response(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int
- dataflow_graph_id: builtins.str
- """The ID of the created graph."""
- def __init__(
- self,
- *,
- dataflow_graph_id: builtins.str | None = ...,
- ) -> None: ...
- def HasField(
- self,
- field_name: typing_extensions.Literal[
- "_dataflow_graph_id",
- b"_dataflow_graph_id",
- "dataflow_graph_id",
- b"dataflow_graph_id",
- ],
- ) -> builtins.bool: ...
- def ClearField(
- self,
- field_name: typing_extensions.Literal[
- "_dataflow_graph_id",
- b"_dataflow_graph_id",
- "dataflow_graph_id",
- b"dataflow_graph_id",
- ],
- ) -> None: ...
- def WhichOneof(
- self,
- oneof_group: typing_extensions.Literal["_dataflow_graph_id", b"_dataflow_graph_id"],
- ) -> typing_extensions.Literal["dataflow_graph_id"] | None: ...
-
DEFAULT_CATALOG_FIELD_NUMBER: builtins.int
DEFAULT_DATABASE_FIELD_NUMBER: builtins.int
SQL_CONF_FIELD_NUMBER: builtins.int
@@ -787,22 +754,106 @@ class PipelineCommandResult(google.protobuf.message.Message):
oneof_group: typing_extensions.Literal["_dataflow_graph_id", b"_dataflow_graph_id"],
) -> typing_extensions.Literal["dataflow_graph_id"] | None: ...
+ class DefineDatasetResult(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ RESOLVED_IDENTIFIER_FIELD_NUMBER: builtins.int
+ @property
+ def resolved_identifier(self) -> pyspark.sql.connect.proto.common_pb2.ResolvedIdentifier:
+ """Resolved identifier of the dataset"""
+ def __init__(
+ self,
+ *,
+ resolved_identifier: pyspark.sql.connect.proto.common_pb2.ResolvedIdentifier
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_resolved_identifier",
+ b"_resolved_identifier",
+ "resolved_identifier",
+ b"resolved_identifier",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_resolved_identifier",
+ b"_resolved_identifier",
+ "resolved_identifier",
+ b"resolved_identifier",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal["_resolved_identifier", b"_resolved_identifier"],
+ ) -> typing_extensions.Literal["resolved_identifier"] | None: ...
+
+ class DefineFlowResult(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ RESOLVED_IDENTIFIER_FIELD_NUMBER: builtins.int
+ @property
+ def resolved_identifier(self) -> pyspark.sql.connect.proto.common_pb2.ResolvedIdentifier:
+ """Resolved identifier of the flow"""
+ def __init__(
+ self,
+ *,
+ resolved_identifier: pyspark.sql.connect.proto.common_pb2.ResolvedIdentifier
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_resolved_identifier",
+ b"_resolved_identifier",
+ "resolved_identifier",
+ b"resolved_identifier",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_resolved_identifier",
+ b"_resolved_identifier",
+ "resolved_identifier",
+ b"resolved_identifier",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal["_resolved_identifier", b"_resolved_identifier"],
+ ) -> typing_extensions.Literal["resolved_identifier"] | None: ...
+
CREATE_DATAFLOW_GRAPH_RESULT_FIELD_NUMBER: builtins.int
+ DEFINE_DATASET_RESULT_FIELD_NUMBER: builtins.int
+ DEFINE_FLOW_RESULT_FIELD_NUMBER: builtins.int
@property
def create_dataflow_graph_result(
self,
) -> global___PipelineCommandResult.CreateDataflowGraphResult: ...
+ @property
+ def define_dataset_result(self) -> global___PipelineCommandResult.DefineDatasetResult: ...
+ @property
+ def define_flow_result(self) -> global___PipelineCommandResult.DefineFlowResult: ...
def __init__(
self,
*,
create_dataflow_graph_result: global___PipelineCommandResult.CreateDataflowGraphResult
| None = ...,
+ define_dataset_result: global___PipelineCommandResult.DefineDatasetResult | None = ...,
+ define_flow_result: global___PipelineCommandResult.DefineFlowResult | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"create_dataflow_graph_result",
b"create_dataflow_graph_result",
+ "define_dataset_result",
+ b"define_dataset_result",
+ "define_flow_result",
+ b"define_flow_result",
"result_type",
b"result_type",
],
@@ -812,13 +863,22 @@ class PipelineCommandResult(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"create_dataflow_graph_result",
b"create_dataflow_graph_result",
+ "define_dataset_result",
+ b"define_dataset_result",
+ "define_flow_result",
+ b"define_flow_result",
"result_type",
b"result_type",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
- ) -> typing_extensions.Literal["create_dataflow_graph_result"] | None: ...
+ ) -> (
+ typing_extensions.Literal[
+ "create_dataflow_graph_result", "define_dataset_result",
"define_flow_result"
+ ]
+ | None
+ ): ...
global___PipelineCommandResult = PipelineCommandResult
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index ceced9313940..625a3272d11b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -36,17 +36,17 @@ sealed trait CatalystIdentifier {
*/
private def quoteIdentifier(name: String): String = name.replace("`", "``")
+ def resolvedId: String = quoteIdentifier(identifier)
+ def resolvedDb: Option[String] = database.map(quoteIdentifier)
+ def resolvedCatalog: Option[String] = catalog.map(quoteIdentifier)
+
def quotedString: String = {
- val replacedId = quoteIdentifier(identifier)
- val replacedDb = database.map(quoteIdentifier)
- val replacedCatalog = catalog.map(quoteIdentifier)
-
- if (replacedCatalog.isDefined && replacedDb.isDefined) {
- s"`${replacedCatalog.get}`.`${replacedDb.get}`.`$replacedId`"
- } else if (replacedDb.isDefined) {
- s"`${replacedDb.get}`.`$replacedId`"
+ if (resolvedCatalog.isDefined && resolvedDb.isDefined) {
+ s"`${resolvedCatalog.get}`.`${resolvedDb.get}`.`$resolvedId`"
+ } else if (resolvedDb.isDefined) {
+ s"`${resolvedDb.get}`.`$resolvedId`"
} else {
- s"`$replacedId`"
+ s"`$resolvedId`"
}
}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/common.proto b/sql/connect/common/src/main/protobuf/spark/connect/common.proto
index 7fab95fa1c9a..c5470538c193 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/common.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/common.proto
@@ -148,6 +148,12 @@ message StackTraceElement {
int32 line_number = 7;
}
+message ResolvedIdentifier {
+ string catalog_name = 1;
+ repeated string namespace = 2;
+ string table_name = 3;
+}
+
message Bools {
repeated bool values = 1;
}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index 57e1ffc7dbe7..f06d9acbaab1 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -20,6 +20,7 @@ syntax = "proto3";
package spark.connect;
import "google/protobuf/timestamp.proto";
+import "spark/connect/common.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";
@@ -48,11 +49,6 @@ message PipelineCommand {
// SQL configurations for all flows in this graph.
map<string, string> sql_conf = 5;
-
- message Response {
- // The ID of the created graph.
- optional string dataflow_graph_id = 1;
- }
}
// Drops the graph and stops any running attached flows.
@@ -146,11 +142,21 @@ message PipelineCommand {
message PipelineCommandResult {
oneof result_type {
CreateDataflowGraphResult create_dataflow_graph_result = 1;
+ DefineDatasetResult define_dataset_result = 2;
+ DefineFlowResult define_flow_result = 3;
}
message CreateDataflowGraphResult {
// The ID of the created graph.
optional string dataflow_graph_id = 1;
}
+ message DefineDatasetResult {
+ // Resolved identifier of the dataset
+ optional ResolvedIdentifier resolved_identifier = 1;
+ }
+ message DefineFlowResult {
+ // Resolved identifier of the flow
+ optional ResolvedIdentifier resolved_identifier = 1;
+ }
}
// The type of dataset.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index f01b9cfb8f09..4ff9818d13d9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -23,7 +23,7 @@ import scala.util.Using
import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation}
+import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation, ResolvedIdentifier}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -81,12 +81,42 @@ private[connect] object PipelinesHandler extends Logging {
defaultResponse
case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET =>
logInfo(s"Define pipelines dataset cmd received: $cmd")
- defineDataset(cmd.getDefineDataset, sessionHolder)
- defaultResponse
+ val resolvedDataset =
+ defineDataset(cmd.getDefineDataset, sessionHolder)
+ val identifierBuilder = ResolvedIdentifier.newBuilder()
+ resolvedDataset.resolvedCatalog.foreach(identifierBuilder.setCatalogName)
+ resolvedDataset.resolvedDb.foreach { ns =>
+ identifierBuilder.addNamespace(ns)
+ }
+ identifierBuilder.setTableName(resolvedDataset.resolvedId)
+ val identifier = identifierBuilder.build()
+ PipelineCommandResult
+ .newBuilder()
+ .setDefineDatasetResult(
+ PipelineCommandResult.DefineDatasetResult
+ .newBuilder()
+ .setResolvedIdentifier(identifier)
+ .build())
+ .build()
case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW =>
logInfo(s"Define pipelines flow cmd received: $cmd")
- defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
- defaultResponse
+ val resolvedFlow =
+ defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
+ val identifierBuilder = ResolvedIdentifier.newBuilder()
+ resolvedFlow.resolvedCatalog.foreach(identifierBuilder.setCatalogName)
+ resolvedFlow.resolvedDb.foreach { ns =>
+ identifierBuilder.addNamespace(ns)
+ }
+ identifierBuilder.setTableName(resolvedFlow.resolvedId)
+ val identifier = identifierBuilder.build()
+ PipelineCommandResult
+ .newBuilder()
+ .setDefineFlowResult(
+ PipelineCommandResult.DefineFlowResult
+ .newBuilder()
+ .setResolvedIdentifier(identifier)
+ .build())
+ .build()
case proto.PipelineCommand.CommandTypeCase.START_RUN =>
logInfo(s"Start pipeline cmd received: $cmd")
startRun(cmd.getStartRun, responseObserver, sessionHolder)
@@ -140,20 +170,23 @@ private[connect] object PipelinesHandler extends Logging {
private def defineDataset(
dataset: proto.PipelineCommand.DefineDataset,
- sessionHolder: SessionHolder): Unit = {
+ sessionHolder: SessionHolder): TableIdentifier = {
val dataflowGraphId = dataset.getDataflowGraphId
val graphElementRegistry =
sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
dataset.getDatasetType match {
case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE =>
- val tableIdentifier =
- GraphIdentifierManager.parseTableIdentifier(
- dataset.getDatasetName,
- sessionHolder.session)
+ val qualifiedIdentifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = GraphIdentifierManager
+ .parseTableIdentifier(dataset.getDatasetName, sessionHolder.session),
+ currentCatalog = Some(graphElementRegistry.defaultCatalog),
+ currentDatabase = Some(graphElementRegistry.defaultDatabase))
+ .identifier
graphElementRegistry.registerTable(
Table(
- identifier = tableIdentifier,
+ identifier = qualifiedIdentifier,
comment = Option(dataset.getComment),
specifiedSchema = Option.when(dataset.hasSchema)(
DataTypeProtoConverter
@@ -164,17 +197,16 @@ private[connect] object PipelinesHandler extends Logging {
properties = dataset.getTablePropertiesMap.asScala.toMap,
baseOrigin = QueryOrigin(
objectType = Option(QueryOriginType.Table.toString),
- objectName = Option(tableIdentifier.unquotedString),
+ objectName = Option(qualifiedIdentifier.unquotedString),
language = Option(Python())),
format = Option.when(dataset.hasFormat)(dataset.getFormat),
normalizedPath = None,
isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE))
+ qualifiedIdentifier
case proto.DatasetType.TEMPORARY_VIEW =>
- val viewIdentifier =
- GraphIdentifierManager.parseTableIdentifier(
- dataset.getDatasetName,
- sessionHolder.session)
-
+ val viewIdentifier = GraphIdentifierManager
+ .parseAndValidateTemporaryViewIdentifier(rawViewIdentifier = GraphIdentifierManager
+ .parseTableIdentifier(dataset.getDatasetName, sessionHolder.session))
graphElementRegistry.registerView(
TemporaryView(
identifier = viewIdentifier,
@@ -185,6 +217,7 @@ private[connect] object PipelinesHandler extends Logging {
language = Option(Python())),
properties = Map.empty,
sqlText = None))
+ viewIdentifier
case _ =>
throw new IllegalArgumentException(s"Unknown dataset type: ${dataset.getDatasetType}")
}
@@ -193,40 +226,65 @@ private[connect] object PipelinesHandler extends Logging {
private def defineFlow(
flow: proto.PipelineCommand.DefineFlow,
transformRelationFunc: Relation => LogicalPlan,
- sessionHolder: SessionHolder): Unit = {
+ sessionHolder: SessionHolder): TableIdentifier = {
val dataflowGraphId = flow.getDataflowGraphId
val graphElementRegistry =
sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+ val defaultCatalog = graphElementRegistry.defaultCatalog
+ val defaultDatabase = graphElementRegistry.defaultDatabase
val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName
-
- val flowIdentifier = GraphIdentifierManager
+ val rawFlowIdentifier = GraphIdentifierManager
.parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session)
// If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then
// it must be a single-part identifier.
- if (!isImplicitFlow &&
!IdentifierHelper.isSinglePartIdentifier(flowIdentifier)) {
+ if (!isImplicitFlow &&
!IdentifierHelper.isSinglePartIdentifier(rawFlowIdentifier)) {
throw new AnalysisException(
"MULTIPART_FLOW_NAME_NOT_SUPPORTED",
Map("flowName" -> flow.getFlowName))
}
+ val rawDestinationIdentifier = GraphIdentifierManager
+ .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session)
+ val flowWritesToView =
+ graphElementRegistry
+ .getViews()
+ .filter(_.isInstanceOf[TemporaryView])
+ .exists(_.identifier == rawDestinationIdentifier)
+
+ // If the flow is created implicitly as part of defining a view, then we do not
+ // qualify the flow identifier and the flow destination. This is because views are
+ // not permitted to have multipart identifiers.
+ val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
+ val Seq(flowIdentifier, destinationIdentifier) =
+ Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
+ if (isImplicitFlowForTempView) {
+ rawIdentifier
+ } else {
+ GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = rawIdentifier,
+ currentCatalog = Some(defaultCatalog),
+ currentDatabase = Some(defaultDatabase))
+ .identifier
+ }
+ }
+
graphElementRegistry.registerFlow(
new UnresolvedFlow(
identifier = flowIdentifier,
- destinationIdentifier = GraphIdentifierManager
- .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session),
+ destinationIdentifier = destinationIdentifier,
func =
FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)),
sqlConf = flow.getSqlConfMap.asScala.toMap,
once = false,
- queryContext = QueryContext(
- Option(graphElementRegistry.defaultCatalog),
- Option(graphElementRegistry.defaultDatabase)),
+ queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)),
origin = QueryOrigin(
objectType = Option(QueryOriginType.Flow.toString),
objectName = Option(flowIdentifier.unquotedString),
language = Option(Python()))))
+ flowIdentifier
}
private def startRun(
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
index ef5da0c014ee..3f92997054e4 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.connect.pipelines
import java.util.UUID
+import scala.jdk.CollectionConverters._
+
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{DatasetType, Expression, PipelineCommand, Relation, UnresolvedTableValuedFunction}
+import org.apache.spark.connect.proto.{DatasetType, Expression, PipelineCommand, PipelineCommandResult, Relation, UnresolvedTableValuedFunction}
import org.apache.spark.connect.proto.PipelineCommand.{DefineDataset, DefineFlow}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
@@ -486,4 +488,294 @@ class SparkDeclarativePipelinesServerSuite
assert(graphsAfter.isEmpty, "Graph should be removed after drop")
}
}
+
+ private case class DefineDatasetTestCase(
+ name: String,
+ datasetType: DatasetType,
+ datasetName: String,
+ defaultCatalog: String = "",
+ defaultDatabase: String = "",
+ expectedResolvedCatalog: String,
+ expectedResolvedNamespace: Seq[String])
+
+ private val defineDatasetDefaultTests = Seq(
+ DefineDatasetTestCase(
+ name = "TEMPORARY_VIEW",
+ datasetType = DatasetType.TEMPORARY_VIEW,
+ datasetName = "tv",
+ expectedResolvedCatalog = "",
+ expectedResolvedNamespace = Seq.empty),
+ DefineDatasetTestCase(
+ name = "TABLE",
+ datasetType = DatasetType.TABLE,
+ datasetName = "tb",
+ expectedResolvedCatalog = "spark_catalog",
+ expectedResolvedNamespace = Seq("default")),
+ DefineDatasetTestCase(
+ name = "MV",
+ datasetType = DatasetType.MATERIALIZED_VIEW,
+ datasetName = "mv",
+ expectedResolvedCatalog = "spark_catalog",
+ expectedResolvedNamespace = Seq("default"))).map(tc => tc.name ->
tc).toMap
+
+ private val defineDatasetCustomTests = Seq(
+ DefineDatasetTestCase(
+ name = "TEMPORARY_VIEW",
+ datasetType = DatasetType.TEMPORARY_VIEW,
+ datasetName = "tv",
+ defaultCatalog = "custom_catalog",
+ defaultDatabase = "custom_db",
+ expectedResolvedCatalog = "",
+ expectedResolvedNamespace = Seq.empty),
+ DefineDatasetTestCase(
+ name = "TABLE",
+ datasetType = DatasetType.TABLE,
+ datasetName = "tb",
+ defaultCatalog = "my_catalog",
+ defaultDatabase = "my_db",
+ expectedResolvedCatalog = "my_catalog",
+ expectedResolvedNamespace = Seq("my_db")),
+ DefineDatasetTestCase(
+ name = "MV",
+ datasetType = DatasetType.MATERIALIZED_VIEW,
+ datasetName = "mv",
+ defaultCatalog = "another_catalog",
+ defaultDatabase = "another_db",
+ expectedResolvedCatalog = "another_catalog",
+ expectedResolvedNamespace = Seq("another_db")))
+ .map(tc => tc.name -> tc)
+ .toMap
+
+ namedGridTest("DefineDataset returns resolved data name for default
catalog/schema")(
+ defineDatasetDefaultTests) { testCase =>
+ withRawBlockingStub { implicit stub =>
+ // Build and send the CreateDataflowGraph command with default catalog/db
+ val graphId = createDataflowGraph
+ assert(Option(graphId).isDefined)
+
+ val defineDataset = DefineDataset
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDatasetName(testCase.datasetName)
+ .setDatasetType(testCase.datasetType)
+ val pipelineCmd = PipelineCommand
+ .newBuilder()
+ .setDefineDataset(defineDataset)
+ .build()
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+
+ assert(res !== PipelineCommandResult.getDefaultInstance)
+ assert(res.hasDefineDatasetResult)
+ val graphResult = res.getDefineDatasetResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.datasetName)
+ }
+ }
+
+ namedGridTest("DefineDataset returns resolved data name for custom
catalog/schema")(
+ defineDatasetCustomTests) { testCase =>
+ withRawBlockingStub { implicit stub =>
+ // Build and send the CreateDataflowGraph command with custom catalog/db
+ val graphId = sendPlan(
+ buildCreateDataflowGraphPlan(
+ proto.PipelineCommand.CreateDataflowGraph
+ .newBuilder()
+ .setDefaultCatalog(testCase.defaultCatalog)
+ .setDefaultDatabase(testCase.defaultDatabase)
+ .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+
+ assert(graphId.nonEmpty)
+
+ // Build DefineDataset with the created graphId and dataset info
+ val defineDataset = DefineDataset
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDatasetName(testCase.datasetName)
+ .setDatasetType(testCase.datasetType)
+ val pipelineCmd = PipelineCommand
+ .newBuilder()
+ .setDefineDataset(defineDataset)
+ .build()
+
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+ assert(res !== PipelineCommandResult.getDefaultInstance)
+ assert(res.hasDefineDatasetResult)
+ val graphResult = res.getDefineDatasetResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.datasetName)
+ }
+ }
+
+ private case class DefineFlowTestCase(
+ name: String,
+ datasetType: DatasetType,
+ flowName: String,
+ defaultCatalog: String,
+ defaultDatabase: String,
+ expectedResolvedCatalog: String,
+ expectedResolvedNamespace: Seq[String])
+
+ private val defineFlowDefaultTests = Seq(
+ DefineFlowTestCase(
+ name = "MV",
+ datasetType = DatasetType.MATERIALIZED_VIEW,
+ flowName = "mv",
+ defaultCatalog = "spark_catalog",
+ defaultDatabase = "default",
+ expectedResolvedCatalog = "spark_catalog",
+ expectedResolvedNamespace = Seq("default")),
+ DefineFlowTestCase(
+ name = "TV",
+ datasetType = DatasetType.TEMPORARY_VIEW,
+ flowName = "tv",
+ defaultCatalog = "spark_catalog",
+ defaultDatabase = "default",
+ expectedResolvedCatalog = "",
+ expectedResolvedNamespace = Seq.empty)).map(tc => tc.name -> tc).toMap
+
+ private val defineFlowCustomTests = Seq(
+ DefineFlowTestCase(
+ name = "MV custom",
+ datasetType = DatasetType.MATERIALIZED_VIEW,
+ flowName = "mv",
+ defaultCatalog = "custom_catalog",
+ defaultDatabase = "custom_db",
+ expectedResolvedCatalog = "custom_catalog",
+ expectedResolvedNamespace = Seq("custom_db")),
+ DefineFlowTestCase(
+ name = "TV custom",
+ datasetType = DatasetType.TEMPORARY_VIEW,
+ flowName = "tv",
+ defaultCatalog = "custom_catalog",
+ defaultDatabase = "custom_db",
+ expectedResolvedCatalog = "",
+ expectedResolvedNamespace = Seq.empty)).map(tc => tc.name -> tc).toMap
+
+ namedGridTest("DefineFlow returns resolved data name for default
catalog/schema")(
+ defineFlowDefaultTests) { testCase =>
+ withRawBlockingStub { implicit stub =>
+ val graphId = createDataflowGraph
+ assert(graphId.nonEmpty)
+
+ // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first
+ if (testCase.datasetType == DatasetType.TEMPORARY_VIEW) {
+ val defineDataset = DefineDataset
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDatasetName(testCase.flowName)
+ .setDatasetType(DatasetType.TEMPORARY_VIEW)
+
+ val defineDatasetCmd = PipelineCommand
+ .newBuilder()
+ .setDefineDataset(defineDataset)
+ .build()
+
+ val datasetRes =
+ sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult
+ assert(datasetRes.hasDefineDatasetResult)
+ }
+
+ val defineFlow = DefineFlow
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setFlowName(testCase.flowName)
+ .setTargetDatasetName(testCase.flowName)
+ .setRelation(
+ Relation
+ .newBuilder()
+ .setUnresolvedTableValuedFunction(
+ UnresolvedTableValuedFunction
+ .newBuilder()
+ .setFunctionName("range")
+ .addArguments(Expression
+ .newBuilder()
+ .setLiteral(Expression.Literal.newBuilder().setInteger(5).build())
+ .build())
+ .build())
+ .build())
+ .build()
+ val pipelineCmd = PipelineCommand
+ .newBuilder()
+ .setDefineFlow(defineFlow)
+ .build()
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+ assert(res.hasDefineFlowResult)
+ val graphResult = res.getDefineFlowResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.flowName)
+ }
+ }
+
+ namedGridTest("DefineFlow returns resolved data name for custom
catalog/schema")(
+ defineFlowCustomTests) { testCase =>
+ withRawBlockingStub { implicit stub =>
+ val graphId = sendPlan(
+ buildCreateDataflowGraphPlan(
+ proto.PipelineCommand.CreateDataflowGraph
+ .newBuilder()
+ .setDefaultCatalog(testCase.defaultCatalog)
+ .setDefaultDatabase(testCase.defaultDatabase)
+ .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+ assert(graphId.nonEmpty)
+
+ // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first
+ if (testCase.datasetType == DatasetType.TEMPORARY_VIEW) {
+ val defineDataset = DefineDataset
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDatasetName(testCase.flowName)
+ .setDatasetType(DatasetType.TEMPORARY_VIEW)
+
+ val defineDatasetCmd = PipelineCommand
+ .newBuilder()
+ .setDefineDataset(defineDataset)
+ .build()
+
+ val datasetRes =
+ sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult
+ assert(datasetRes.hasDefineDatasetResult)
+ }
+
+ val defineFlow = DefineFlow
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setFlowName(testCase.flowName)
+ .setTargetDatasetName(testCase.flowName)
+ .setRelation(
+ Relation
+ .newBuilder()
+ .setUnresolvedTableValuedFunction(
+ UnresolvedTableValuedFunction
+ .newBuilder()
+ .setFunctionName("range")
+ .addArguments(Expression
+ .newBuilder()
+ .setLiteral(Expression.Literal.newBuilder().setInteger(5).build())
+ .build())
+ .build())
+ .build())
+ .build()
+ val pipelineCmd = PipelineCommand
+ .newBuilder()
+ .setDefineFlow(defineFlow)
+ .build()
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+ assert(res.hasDefineFlowResult)
+ val graphResult = res.getDefineFlowResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.flowName)
+ }
+ }
}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
index 4494bbe0d310..b4f8315cc3fd 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
@@ -45,6 +45,10 @@ class GraphRegistrationContext(
views += viewDef
}
+ def getViews(): Seq[View] = views.toSeq
+
def registerFlow(flowDef: UnresolvedFlow): Unit = {
flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf)
}
@@ -57,79 +61,17 @@ class GraphRegistrationContext(
errorClass = "RUN_EMPTY_PIPELINE",
messageParameters = Map.empty)
}
- val qualifiedTables = tables.toSeq.map { t =>
- t.copy(
- identifier = GraphIdentifierManager
- .parseAndQualifyTableIdentifier(
- rawTableIdentifier = t.identifier,
- currentCatalog = Some(defaultCatalog),
- currentDatabase = Some(defaultDatabase)
- )
- .identifier
- )
- }
-
- val validatedViews = views.toSeq.collect {
- case v: TemporaryView =>
- v.copy(
- identifier = GraphIdentifierManager
- .parseAndValidateTemporaryViewIdentifier(
- rawViewIdentifier = v.identifier
- )
- )
- case v: PersistedView =>
- v.copy(
- identifier = GraphIdentifierManager
- .parseAndValidatePersistedViewIdentifier(
- rawViewIdentifier = v.identifier,
- currentCatalog = Some(defaultCatalog),
- currentDatabase = Some(defaultDatabase)
- )
- )
- }
-
- val qualifiedFlows = flows.toSeq.map { f =>
- val isImplicitFlow = f.identifier == f.destinationIdentifier
- val flowWritesToView =
- validatedViews
- .filter(_.isInstanceOf[TemporaryView])
- .exists(_.identifier == f.destinationIdentifier)
-
- // If the flow is created implicitly as part of defining a view, then we do not
- // qualify the flow identifier and the flow destination. This is because views are
- // not permitted to have multipart identifiers.
- if (isImplicitFlow && flowWritesToView) {
- f
- } else {
- f.copy(
- identifier = GraphIdentifierManager
- .parseAndQualifyFlowIdentifier(
- rawFlowIdentifier = f.identifier,
- currentCatalog = Some(defaultCatalog),
- currentDatabase = Some(defaultDatabase)
- )
- .identifier,
- destinationIdentifier = GraphIdentifierManager
- .parseAndQualifyFlowIdentifier(
- rawFlowIdentifier = f.destinationIdentifier,
- currentCatalog = Some(defaultCatalog),
- currentDatabase = Some(defaultDatabase)
- )
- .identifier
- )
- }
- }
assertNoDuplicates(
- qualifiedTables = qualifiedTables,
- validatedViews = validatedViews,
- qualifiedFlows = qualifiedFlows
+ qualifiedTables = tables.toSeq,
+ validatedViews = views.toSeq,
+ qualifiedFlows = flows.toSeq
)
new DataflowGraph(
- tables = qualifiedTables,
- views = validatedViews,
- flows = qualifiedFlows
+ tables = tables.toSeq,
+ views = views.toSeq,
+ flows = flows.toSeq
)
}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index d0a8236734d3..38bd858a688f 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -130,9 +130,16 @@ class TestGraphRegistrationContext(
): Unit = {
// scalastyle:on
val tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
+ val qualifiedIdentifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = GraphIdentifierManager
+ .parseTableIdentifier(name, spark),
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase)))
+ .identifier
registerTable(
Table(
- identifier = GraphIdentifierManager.parseTableIdentifier(name, spark),
+ identifier = qualifiedIdentifier,
comment = comment,
specifiedSchema = specifiedSchema,
partitionCols = partitionCols,
@@ -147,8 +154,8 @@ class TestGraphRegistrationContext(
if (query.isDefined) {
registerFlow(
new UnresolvedFlow(
- identifier = tableIdentifier,
- destinationIdentifier = tableIdentifier,
+ identifier = qualifiedIdentifier,
+ destinationIdentifier = qualifiedIdentifier,
func = query.get,
queryContext = QueryContext(
currentCatalog = catalog.orElse(Some(defaultCatalog)),
@@ -193,9 +200,21 @@ class TestGraphRegistrationContext(
sqlText: Option[String] = None
): Unit = {
- val viewIdentifier = GraphIdentifierManager
+ val tempViewIdentifier = GraphIdentifierManager
.parseAndValidateTemporaryViewIdentifier(rawViewIdentifier = TableIdentifier(name))
+ val persistedViewIdentifier = GraphIdentifierManager
+ .parseAndValidatePersistedViewIdentifier(
+ rawViewIdentifier = TableIdentifier(name),
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase))
+ )
+
+ val viewIdentifier: TableIdentifier = viewType match {
+ case LocalTempView => tempViewIdentifier
+ case _ => persistedViewIdentifier
+ }
+
registerView(
viewType match {
case LocalTempView =>
@@ -241,10 +260,33 @@ class TestGraphRegistrationContext(
catalog: Option[String] = None,
database: Option[String] = None
): Unit = {
- val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
- val flowDestinationIdentifier =
+ val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
+ val rawDestinationIdentifier =
GraphIdentifierManager.parseTableIdentifier(destinationName, spark)
+ val flowWritesToView = getViews()
+ .filter(_.isInstanceOf[TemporaryView])
+ .exists(_.identifier == rawDestinationIdentifier)
+
+ // If the flow is created implicitly as part of defining a view, then we do not
+ // qualify the flow identifier and the flow destination. This is because views are
+ // not permitted to have multipart identifiers.
+ val isImplicitFlow = rawFlowIdentifier == rawDestinationIdentifier
+ val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
+ val Seq(flowIdentifier, flowDestinationIdentifier) =
+ Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
+ if (isImplicitFlowForTempView) {
+ rawIdentifier
+ } else {
+ GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = rawIdentifier,
+ currentCatalog = catalog.orElse(Some(defaultCatalog)),
+ currentDatabase = database.orElse(Some(defaultDatabase)))
+ .identifier
+ }
+ }
+
registerFlow(
new UnresolvedFlow(
identifier = flowIdentifier,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]