This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1e50354c4bbd [SPARK-54562] Block eager analysis / execution inside
flow function from the server side
1e50354c4bbd is described below
commit 1e50354c4bbda957b2ccfbcbdb85629fd8cf6036
Author: Yuheng Chang <[email protected]>
AuthorDate: Thu Dec 4 06:27:07 2025 -0800
[SPARK-54562] Block eager analysis / execution inside flow function from
the server side
### What changes were proposed in this pull request?
Eager analysis / execution inside the SDP flow function is not supported.
Previously, we enforced this restriction by intercepting the `AnalyzePlan` /
`ExecutePlan` RPCs on the Python Spark Connect client side, which is fragile.
This is mostly because, at that time, we did not have a good way for the SC
server to know whether it was serving a request made inside a flow
function.
Now, we have `PipelineAnalysisContext`. If the request comes from a flow
function, the SC client automatically adds a `PipelineAnalysisContext` UserContext
extension, and the SC server can rely on its existence to know whether the request
comes from a flow function. Thereby, we're now able to move the eager analysis /
execution validation from the client to the server.
### Why are the changes needed?
Our previous way of performing such validation was very fragile, and this PR
serves as a code improvement.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests in PythonPipelineSuite
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53303 from SCHJonathan/jonathan-chang_data/block-eager-analysis.
Authored-by: Yuheng Chang <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 ++
dev/sparktestsupport/modules.py | 1 -
python/pyspark/errors/error-conditions.json | 5 --
python/pyspark/pipelines/block_connect_access.py | 86 ----------------------
.../spark_connect_graph_element_registry.py | 4 +-
.../pipelines/tests/test_block_connect_access.py | 64 ----------------
.../execution/SparkConnectPlanExecution.scala | 14 +++-
.../sql/connect/planner/SparkConnectPlanner.scala | 27 ++-----
.../service/SparkConnectAnalyzeHandler.scala | 10 ++-
.../utils/PipelineAnalysisContextUtils.scala | 85 +++++++++++++++++++++
.../connect/pipelines/PythonPipelineSuite.scala | 64 ++++++++++++++++
11 files changed, 182 insertions(+), 184 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 4bd4f3cfc764..ac1cde10bdf5 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -173,6 +173,12 @@
},
"sqlState" : "42604"
},
+ "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION" : {
+ "message" : [
+ "Operations that trigger DataFrame analysis or execution are not allowed
in pipeline query functions. Move code outside of the pipeline query function."
+ ],
+ "sqlState" : "0A000"
+ },
"AVRO_CANNOT_WRITE_NULL_FIELD" : {
"message" : [
"Cannot write null value for field <name> defined as non-null Avro data
type <dataType>.",
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 473ec5cdbff1..306a3b69223f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1566,7 +1566,6 @@ pyspark_pipelines = Module(
dependencies=[pyspark_core, pyspark_sql, pyspark_connect],
source_file_regexes=["python/pyspark/pipelines"],
python_test_goals=[
- "pyspark.pipelines.tests.test_block_connect_access",
"pyspark.pipelines.tests.test_block_session_mutations",
"pyspark.pipelines.tests.test_cli",
"pyspark.pipelines.tests.test_decorators",
diff --git a/python/pyspark/errors/error-conditions.json
b/python/pyspark/errors/error-conditions.json
index 295b372cade5..c2928442971d 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -14,11 +14,6 @@
"Arrow legacy IPC format is not supported in PySpark, please unset
ARROW_PRE_0_15_IPC_FORMAT."
]
},
- "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION": {
- "message": [
- "Operations that trigger DataFrame analysis or execution are not allowed
in pipeline query functions. Move code outside of the pipeline query function."
- ]
- },
"ATTRIBUTE_NOT_CALLABLE": {
"message": [
"Attribute `<attr_name>` in provided object `<obj_name>` is not
callable."
diff --git a/python/pyspark/pipelines/block_connect_access.py
b/python/pyspark/pipelines/block_connect_access.py
deleted file mode 100644
index 696d0e39b005..000000000000
--- a/python/pyspark/pipelines/block_connect_access.py
+++ /dev/null
@@ -1,86 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from contextlib import contextmanager
-from typing import Any, Callable, Generator
-
-from pyspark.errors import PySparkException
-from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
-
-
-BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]
-
-
-def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
- """
- Check if the RPC call is a spark.sql() command (ExecutePlan with
sql_command).
-
- :param rpc_name: Name of the RPC being called
- :param args: Arguments passed to the RPC
- :return: True if this is an ExecutePlan request with a sql_command
- """
- if rpc_name != "ExecutePlan" or len(args) == 0:
- return False
-
- request = args[0]
- if not hasattr(request, "plan"):
- return False
- plan = request.plan
- if not plan.HasField("command"):
- return False
- command = plan.command
- return command.HasField("sql_command")
-
-
-@contextmanager
-def block_spark_connect_execution_and_analysis() -> Generator[None, None,
None]:
- """
- A context manager that blocks execution and analysis RPCs to the Spark
Connect backend
- by intercepting method calls on SparkConnectServiceStub instances.
-
- :param error_message : Custom error message to display when communication
is blocked.
- If not provided, a default message will be used.
- """
- # Store the original __getattribute__ method
- original_getattr = getattr(SparkConnectServiceStub, "__getattribute__")
-
- # Define a new __getattribute__ method that blocks RPC calls
- def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
- original_method = original_getattr(self, name)
-
- def intercepted_method(*args: object, **kwargs: object) -> Any:
- # Allow all RPCs that are not AnalyzePlan or ExecutePlan
- if name not in BLOCKED_RPC_NAMES:
- return original_method(*args, **kwargs)
- # Allow spark.sql() commands (ExecutePlan with sql_command)
- elif _is_sql_command_request(name, args):
- return original_method(*args, **kwargs)
- # Block all other AnalyzePlan and ExecutePlan calls
- else:
- raise PySparkException(
- errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
- messageParameters={},
- )
-
- return intercepted_method
-
- try:
- # Apply our custom __getattribute__ method
- setattr(SparkConnectServiceStub, "__getattribute__", blocked_getattr)
- yield
- finally:
- # Restore the original __getattribute__ method
- setattr(SparkConnectServiceStub, "__getattribute__", original_getattr)
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index b8d297fced3f..ab8831790830 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -19,7 +19,6 @@ from pathlib import Path
from pyspark.errors import PySparkTypeError
from pyspark.sql import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-from pyspark.pipelines.block_connect_access import
block_spark_connect_execution_and_analysis
from pyspark.pipelines.output import (
Output,
MaterializedView,
@@ -115,8 +114,7 @@ class
SparkConnectGraphElementRegistry(GraphElementRegistry):
with add_pipeline_analysis_context(
spark=self._spark, dataflow_graph_id=self._dataflow_graph_id,
flow_name=flow.name
):
- with block_spark_connect_execution_and_analysis():
- df = flow.func()
+ df = flow.func()
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
relation_flow_details =
pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
diff --git a/python/pyspark/pipelines/tests/test_block_connect_access.py
b/python/pyspark/pipelines/tests/test_block_connect_access.py
deleted file mode 100644
index 60688f30bfb9..000000000000
--- a/python/pyspark/pipelines/tests/test_block_connect_access.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import unittest
-
-from pyspark.errors import PySparkException
-from pyspark.testing.connectutils import (
- ReusedConnectTestCase,
- should_test_connect,
- connect_requirement_message,
-)
-
-if should_test_connect:
- from pyspark.pipelines.block_connect_access import
block_spark_connect_execution_and_analysis
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class BlockSparkConnectAccessTests(ReusedConnectTestCase):
- def test_create_dataframe_not_blocked(self):
- with block_spark_connect_execution_and_analysis():
- self.spark.createDataFrame([(1,)], ["id"])
-
- def test_schema_access_blocked(self):
- df = self.spark.createDataFrame([(1,)], ["id"])
-
- with block_spark_connect_execution_and_analysis():
- with self.assertRaises(PySparkException) as context:
- df.schema
- self.assertEqual(
- context.exception.getCondition(),
"ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"
- )
-
- def test_collect_blocked(self):
- df = self.spark.createDataFrame([(1,)], ["id"])
-
- with block_spark_connect_execution_and_analysis():
- with self.assertRaises(PySparkException) as context:
- df.collect()
- self.assertEqual(
- context.exception.getCondition(),
"ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"
- )
-
-
-if __name__ == "__main__":
- try:
- import xmlrunner # type: ignore
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index f5cb2696d849..240dd05c899d 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -27,7 +27,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -35,7 +35,7 @@ import
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralP
import
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE,
CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
import org.apache.spark.sql.connect.planner.{InvalidInputErrors,
SparkConnectPlanner}
import org.apache.spark.sql.connect.service.ExecuteHolder
-import org.apache.spark.sql.connect.utils.MetricGenerator
+import org.apache.spark.sql.connect.utils.{MetricGenerator,
PipelineAnalysisContextUtils}
import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec,
QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.internal.SQLConf
@@ -65,6 +65,16 @@ private[execution] class
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
} else {
DoNotCleanup
}
+ val userContext = request.getUserContext
+
+ // if the eager execution is triggered inside pipeline flow function,
+ // block the eager execution command that is not allowed.
+ if
(PipelineAnalysisContextUtils.isUnsupportedEagerExecutionInsideFlowFunction(
+ userContext = userContext,
+ plan = request.getPlan)) {
+ throw new
AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
+ }
+
request.getPlan.getOpTypeCase match {
case proto.Plan.OpTypeCase.ROOT =>
val dataframe =
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9af2e7cb4661..7e17a935f599 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -21,12 +21,11 @@ import java.util.{HashMap, Properties, UUID}
import scala.collection.mutable
import scala.jdk.CollectionConverters._
-import scala.reflect.ClassTag
import scala.util.Try
import scala.util.control.NonFatal
import com.google.common.collect.Lists
-import com.google.protobuf.{Any => ProtoAny, ByteString, Message}
+import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
@@ -34,7 +33,7 @@ import org.apache.spark.{SparkClassNotFoundException,
SparkEnv, SparkException,
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand,
CreateResourceProfileCommand, ExecutePlanResponse, PipelineAnalysisContext,
SqlCommand, StreamingForeachFunction, StreamingQueryCommand,
StreamingQueryCommandResult, StreamingQueryInstanceId,
StreamingQueryManagerCommand, StreamingQueryManagerCommandResult,
WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand,
CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand,
StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult,
StreamingQueryInstanceId, StreamingQueryManagerCommand,
StreamingQueryManagerCommandResult, WriteStreamOperationStart,
WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import
org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -66,7 +65,7 @@ import org.apache.spark.sql.connect.ml.MLHandler
import org.apache.spark.sql.connect.pipelines.PipelinesHandler
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder,
SparkConnectService}
-import org.apache.spark.sql.connect.utils.MetricGenerator
+import org.apache.spark.sql.connect.utils.{MetricGenerator,
PipelineAnalysisContextUtils}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator,
TypedAggregateExpression}
@@ -2942,25 +2941,11 @@ class SparkConnectPlanner(
.build())
}
- private def getExtensionList[T <: Message: ClassTag](
- extensions: mutable.Buffer[ProtoAny]): Seq[T] = {
- val cls = implicitly[ClassTag[T]].runtimeClass
- .asInstanceOf[Class[_ <: Message]]
- extensions.collect {
- case any if any.is(cls) => any.unpack(cls).asInstanceOf[T]
- }.toSeq
- }
-
private def handleSqlCommand(
command: SqlCommand,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
- val userContextExtensions =
executeHolder.request.getUserContext.getExtensionsList.asScala
- val pipelineAnalysisContextList = {
- getExtensionList[PipelineAnalysisContext](userContextExtensions)
- }
- val hasPipelineAnalysisContext = pipelineAnalysisContextList.nonEmpty
- val insidePipelineFlowFunction =
pipelineAnalysisContextList.exists(_.hasFlowName)
+ val userContext = executeHolder.request.getUserContext
// To avoid explicit handling of the result on the client, we build the
expected input
// of the relation on the server. The client has to simply forward the
result.
val result = SqlCommandResult.newBuilder()
@@ -2984,13 +2969,13 @@ class SparkConnectPlanner(
}
// Block unsupported SQL commands if the request comes from Spark
Declarative Pipelines.
- if (hasPipelineAnalysisContext) {
+ if (PipelineAnalysisContextUtils.hasPipelineAnalysisContext(userContext)) {
PipelinesHandler.blockUnsupportedSqlCommand(queryPlan =
transformRelation(relation))
}
// If the spark.sql() is called inside a pipeline flow function, we don't
need to execute
// the SQL command and defer the actual analysis and execution to the flow
function.
- if (insidePipelineFlowFunction) {
+ if
(PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) {
result.setRelation(relation)
executeHolder.eventsManager.postFinished()
responseObserver.onNext(
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index cdf7013211f7..ec8d95271c76 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,13 +23,13 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.utils.PlanCompressionUtils
+import org.apache.spark.sql.connect.utils.{PipelineAnalysisContextUtils,
PlanCompressionUtils}
import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode,
CostMode, ExtendedMode, FormattedMode, SimpleMode}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -62,6 +62,12 @@ private[connect] class SparkConnectAnalyzeHandler(
lazy val planner = new SparkConnectPlanner(sessionHolder)
val session = sessionHolder.session
val builder = proto.AnalyzePlanResponse.newBuilder()
+ val userContext = request.getUserContext
+
+ // Pipeline has not yet supported eager analysis inside flow function.
+ if
(PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) {
+ throw new
AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
+ }
def transformRelation(rel: proto.Relation) =
planner.transformRelation(rel, cachePlan = true)
def transformRelationPlan(plan: proto.Plan) = {
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
new file mode 100644
index 000000000000..75630c45c8c5
--- /dev/null
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.utils
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
+
+import com.google.protobuf.{Any => ProtoAny, Message}
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{PipelineAnalysisContext, UserContext}
+
+/**
+ * Utilities for working with PipelineAnalysisContext in the user context.
+ */
+object PipelineAnalysisContextUtils {
+
+ /** Get a list of extensions from the user context. */
+ private def getExtensionList[T <: Message: ClassTag](
+ extensions: mutable.Buffer[ProtoAny]): Seq[T] = {
+ val cls = implicitly[ClassTag[T]].runtimeClass
+ .asInstanceOf[Class[_ <: Message]]
+ extensions.collect {
+ case any if any.is(cls) => any.unpack(cls).asInstanceOf[T]
+ }.toSeq
+ }
+
+ /** Get a list of PipelineAnalysisContext extensions from the user context.
*/
+ private def getPipelineAnalysisContextList(
+ userContext: UserContext): Seq[PipelineAnalysisContext] = {
+ val userContextExtensions = userContext.getExtensionsList.asScala
+ getExtensionList[PipelineAnalysisContext](userContextExtensions)
+ }
+
+ /** Return whether the execution / analysis is inside a pipeline flow
function. */
+ def hasPipelineAnalysisContext(userContext: UserContext): Boolean = {
+ getPipelineAnalysisContextList(userContext).nonEmpty
+ }
+
+ /** Return whether the execution / analysis is inside a pipeline flow
function. */
+ def isInsidePipelineFlowFunction(userContext: UserContext): Boolean = {
+ getPipelineAnalysisContextList(userContext).exists(_.hasFlowName)
+ }
+
+ /**
+ * Return whether the eager execution is triggered inside pipeline flow
function but the eager
+ * execution type is not allowed.
+ */
+ def isUnsupportedEagerExecutionInsideFlowFunction(
+ userContext: UserContext,
+ plan: proto.Plan): Boolean = {
+ // if the eager execution is not triggered inside pipeline flow function,
+ // don't block it.
+ if (!isInsidePipelineFlowFunction(userContext)) {
+ return false
+ }
+
+ plan.getOpTypeCase match {
+ // Root plan (e.g., df.collect(), df.first()) are not allowed inside
flow function
+ case proto.Plan.OpTypeCase.ROOT => true
+ case proto.Plan.OpTypeCase.COMMAND =>
+ // only spark.sql() command is allowed inside flow function.
+ val commandAllowList = Set(proto.Command.CommandTypeCase.SQL_COMMAND)
+ !commandAllowList.contains(plan.getCommand.getCommandTypeCase)
+ // For other Plan type introduced in the future, default to block it
+ case _ => true
+ }
+ }
+}
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 98b33c3296fa..fd05b0cc357e 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -1105,6 +1105,70 @@ class PythonPipelineSuite
|""".stripMargin)
}
+ private val eagerExecutionPythonCommands = Seq(
+ "df.collect()",
+ "df.first()",
+ "df.head(0)",
+ "df.toPandas()",
+ "spark.readStream.format(\"rate\").load().writeStream" +
+ ".format(\"memory\").queryName(\"test_query_name\").start()")
+
+ gridTest("unsupported eager execution inside flow function is blocked")(
+ eagerExecutionPythonCommands) { unsupportedEagerExecutionCommand =>
+ val ex = intercept[RuntimeException] {
+ buildGraph(s"""
+ |@dp.materialized_view()
+ |def mv():
+ | df = spark.range(5)
+ | $unsupportedEagerExecutionCommand
+ | return df
+ |""".stripMargin)
+ }
+
assert(ex.getMessage.contains("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"))
+ }
+
+ gridTest("eager execution outside flow function is
allowed")(eagerExecutionPythonCommands) {
+ unsupportedEagerExecutionCommand =>
+ buildGraph(s"""
+ |df = spark.range(5)
+ |$unsupportedEagerExecutionCommand
+ |
+ |@dp.materialized_view()
+ |def mv():
+ | df = spark.range(5)
+ | return df
+ |""".stripMargin)
+ }
+
+ private val eagerAnalysisPythonCommands = Seq("df.schema", "df.isStreaming",
"df.isLocal()")
+
+ gridTest("eager analysis inside flow function is
blocked")(eagerAnalysisPythonCommands) {
+ eagerAnalysisPythonCommand =>
+ val ex = intercept[RuntimeException] {
+ buildGraph(s"""
+ |@dp.materialized_view()
+ |def mv():
+ | df = spark.range(5)
+ | $eagerAnalysisPythonCommand
+ | return df
+ |""".stripMargin)
+ }
+
assert(ex.getMessage.contains("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"))
+ }
+
+ gridTest("eager analysis outside flow function is
allowed")(eagerAnalysisPythonCommands) {
+ eagerAnalysisPythonCommand =>
+ buildGraph(s"""
+ |df = spark.range(5)
+ |$eagerAnalysisPythonCommand
+ |
+ |@dp.materialized_view()
+ |def mv():
+ | df = spark.range(5)
+ | return df
+ |""".stripMargin)
+ }
+
override protected def test(testName: String, testTags: Tag*)(testFun: =>
Any)(implicit
pos: Position): Unit = {
if (PythonTestDepsChecker.isConnectDepsAvailable) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]