This is an automated email from the ASF dual-hosted git repository.

sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e50354c4bbd [SPARK-54562] Block eager analysis / execution inside 
flow function from the server side
1e50354c4bbd is described below

commit 1e50354c4bbda957b2ccfbcbdb85629fd8cf6036
Author: Yuheng Chang <[email protected]>
AuthorDate: Thu Dec 4 06:27:07 2025 -0800

    [SPARK-54562] Block eager analysis / execution inside flow function from 
the server side
    
    ### What changes were proposed in this pull request?
    
    Eager analysis / execution inside an SDP flow function is not supported.
    Previously, we enforced this restriction by intercepting the `AnalyzePlan` /
    `ExecutePlan` RPCs on the Python Spark Connect client side, which is fragile.
    This was mostly because, at the time, the SC server had no good way to know
    whether it was serving a request made inside a flow function.
    
    Now we have `PipelineAnalysisContext`. When a request originates from a flow
    function, the SC client automatically attaches a `PipelineAnalysisContext`
    UserContext extension carrying the flow name, and the SC server can use its
    presence to determine whether the request comes from a flow function. This
    lets us move the eager analysis / execution validation from the client to the
    server.
    
    ### Why are the changes needed?
    
    Our previous approach to this validation was fragile; this PR is a code
    improvement that moves the check to the server.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests in PythonPipelineSuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53303 from SCHJonathan/jonathan-chang_data/block-eager-analysis.
    
    Authored-by: Yuheng Chang <[email protected]>
    Signed-off-by: Sandy Ryza <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  6 ++
 dev/sparktestsupport/modules.py                    |  1 -
 python/pyspark/errors/error-conditions.json        |  5 --
 python/pyspark/pipelines/block_connect_access.py   | 86 ----------------------
 .../spark_connect_graph_element_registry.py        |  4 +-
 .../pipelines/tests/test_block_connect_access.py   | 64 ----------------
 .../execution/SparkConnectPlanExecution.scala      | 14 +++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 27 ++-----
 .../service/SparkConnectAnalyzeHandler.scala       | 10 ++-
 .../utils/PipelineAnalysisContextUtils.scala       | 85 +++++++++++++++++++++
 .../connect/pipelines/PythonPipelineSuite.scala    | 64 ++++++++++++++++
 11 files changed, 182 insertions(+), 184 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 4bd4f3cfc764..ac1cde10bdf5 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -173,6 +173,12 @@
     },
     "sqlState" : "42604"
   },
+  "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION" : {
+    "message" : [
+      "Operations that trigger DataFrame analysis or execution are not allowed 
in pipeline query functions. Move code outside of the pipeline query function."
+    ],
+    "sqlState" : "0A000"
+  },
   "AVRO_CANNOT_WRITE_NULL_FIELD" : {
     "message" : [
       "Cannot write null value for field <name> defined as non-null Avro data 
type <dataType>.",
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 473ec5cdbff1..306a3b69223f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1566,7 +1566,6 @@ pyspark_pipelines = Module(
     dependencies=[pyspark_core, pyspark_sql, pyspark_connect],
     source_file_regexes=["python/pyspark/pipelines"],
     python_test_goals=[
-        "pyspark.pipelines.tests.test_block_connect_access",
         "pyspark.pipelines.tests.test_block_session_mutations",
         "pyspark.pipelines.tests.test_cli",
         "pyspark.pipelines.tests.test_decorators",
diff --git a/python/pyspark/errors/error-conditions.json 
b/python/pyspark/errors/error-conditions.json
index 295b372cade5..c2928442971d 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -14,11 +14,6 @@
       "Arrow legacy IPC format is not supported in PySpark, please unset 
ARROW_PRE_0_15_IPC_FORMAT."
     ]
   },
-  "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION": {
-    "message": [
-      "Operations that trigger DataFrame analysis or execution are not allowed 
in pipeline query functions. Move code outside of the pipeline query function."
-    ]
-  },
   "ATTRIBUTE_NOT_CALLABLE": {
     "message": [
       "Attribute `<attr_name>` in provided object `<obj_name>` is not 
callable."
diff --git a/python/pyspark/pipelines/block_connect_access.py 
b/python/pyspark/pipelines/block_connect_access.py
deleted file mode 100644
index 696d0e39b005..000000000000
--- a/python/pyspark/pipelines/block_connect_access.py
+++ /dev/null
@@ -1,86 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from contextlib import contextmanager
-from typing import Any, Callable, Generator
-
-from pyspark.errors import PySparkException
-from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
-
-
-BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]
-
-
-def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
-    """
-    Check if the RPC call is a spark.sql() command (ExecutePlan with 
sql_command).
-
-    :param rpc_name: Name of the RPC being called
-    :param args: Arguments passed to the RPC
-    :return: True if this is an ExecutePlan request with a sql_command
-    """
-    if rpc_name != "ExecutePlan" or len(args) == 0:
-        return False
-
-    request = args[0]
-    if not hasattr(request, "plan"):
-        return False
-    plan = request.plan
-    if not plan.HasField("command"):
-        return False
-    command = plan.command
-    return command.HasField("sql_command")
-
-
-@contextmanager
-def block_spark_connect_execution_and_analysis() -> Generator[None, None, 
None]:
-    """
-    A context manager that blocks execution and analysis RPCs to the Spark 
Connect backend
-    by intercepting method calls on SparkConnectServiceStub instances.
-
-    :param error_message : Custom error message to display when communication 
is blocked.
-        If not provided, a default message will be used.
-    """
-    # Store the original __getattribute__ method
-    original_getattr = getattr(SparkConnectServiceStub, "__getattribute__")
-
-    # Define a new __getattribute__ method that blocks RPC calls
-    def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
-        original_method = original_getattr(self, name)
-
-        def intercepted_method(*args: object, **kwargs: object) -> Any:
-            # Allow all RPCs that are not AnalyzePlan or ExecutePlan
-            if name not in BLOCKED_RPC_NAMES:
-                return original_method(*args, **kwargs)
-            # Allow spark.sql() commands (ExecutePlan with sql_command)
-            elif _is_sql_command_request(name, args):
-                return original_method(*args, **kwargs)
-            # Block all other AnalyzePlan and ExecutePlan calls
-            else:
-                raise PySparkException(
-                    errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
-                    messageParameters={},
-                )
-
-        return intercepted_method
-
-    try:
-        # Apply our custom __getattribute__ method
-        setattr(SparkConnectServiceStub, "__getattribute__", blocked_getattr)
-        yield
-    finally:
-        # Restore the original __getattribute__ method
-        setattr(SparkConnectServiceStub, "__getattribute__", original_getattr)
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py 
b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index b8d297fced3f..ab8831790830 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -19,7 +19,6 @@ from pathlib import Path
 from pyspark.errors import PySparkTypeError
 from pyspark.sql import SparkSession
 from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-from pyspark.pipelines.block_connect_access import 
block_spark_connect_execution_and_analysis
 from pyspark.pipelines.output import (
     Output,
     MaterializedView,
@@ -115,8 +114,7 @@ class 
SparkConnectGraphElementRegistry(GraphElementRegistry):
         with add_pipeline_analysis_context(
             spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, 
flow_name=flow.name
         ):
-            with block_spark_connect_execution_and_analysis():
-                df = flow.func()
+            df = flow.func()
         relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
 
         relation_flow_details = 
pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
diff --git a/python/pyspark/pipelines/tests/test_block_connect_access.py 
b/python/pyspark/pipelines/tests/test_block_connect_access.py
deleted file mode 100644
index 60688f30bfb9..000000000000
--- a/python/pyspark/pipelines/tests/test_block_connect_access.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import unittest
-
-from pyspark.errors import PySparkException
-from pyspark.testing.connectutils import (
-    ReusedConnectTestCase,
-    should_test_connect,
-    connect_requirement_message,
-)
-
-if should_test_connect:
-    from pyspark.pipelines.block_connect_access import 
block_spark_connect_execution_and_analysis
-
-
[email protected](not should_test_connect, connect_requirement_message)
-class BlockSparkConnectAccessTests(ReusedConnectTestCase):
-    def test_create_dataframe_not_blocked(self):
-        with block_spark_connect_execution_and_analysis():
-            self.spark.createDataFrame([(1,)], ["id"])
-
-    def test_schema_access_blocked(self):
-        df = self.spark.createDataFrame([(1,)], ["id"])
-
-        with block_spark_connect_execution_and_analysis():
-            with self.assertRaises(PySparkException) as context:
-                df.schema
-            self.assertEqual(
-                context.exception.getCondition(), 
"ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"
-            )
-
-    def test_collect_blocked(self):
-        df = self.spark.createDataFrame([(1,)], ["id"])
-
-        with block_spark_connect_execution_and_analysis():
-            with self.assertRaises(PySparkException) as context:
-                df.collect()
-            self.assertEqual(
-                context.exception.getCondition(), 
"ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"
-            )
-
-
-if __name__ == "__main__":
-    try:
-        import xmlrunner  # type: ignore
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index f5cb2696d849..240dd05c899d 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -27,7 +27,7 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -35,7 +35,7 @@ import 
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralP
 import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE, 
CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, 
SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
-import org.apache.spark.sql.connect.utils.MetricGenerator
+import org.apache.spark.sql.connect.utils.{MetricGenerator, 
PipelineAnalysisContextUtils}
 import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, 
QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
@@ -65,6 +65,16 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       } else {
         DoNotCleanup
       }
+    val userContext = request.getUserContext
+
+    // if the eager execution is triggered inside pipeline flow function,
+    // block the eager execution command that is not allowed.
+    if 
(PipelineAnalysisContextUtils.isUnsupportedEagerExecutionInsideFlowFunction(
+        userContext = userContext,
+        plan = request.getPlan)) {
+      throw new 
AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
+    }
+
     request.getPlan.getOpTypeCase match {
       case proto.Plan.OpTypeCase.ROOT =>
         val dataframe =
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9af2e7cb4661..7e17a935f599 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -21,12 +21,11 @@ import java.util.{HashMap, Properties, UUID}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
-import scala.reflect.ClassTag
 import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.common.collect.Lists
-import com.google.protobuf.{Any => ProtoAny, ByteString, Message}
+import com.google.protobuf.{Any => ProtoAny, ByteString}
 import io.grpc.{Context, Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
@@ -34,7 +33,7 @@ import org.apache.spark.{SparkClassNotFoundException, 
SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, 
CreateResourceProfileCommand, ExecutePlanResponse, PipelineAnalysisContext, 
SqlCommand, StreamingForeachFunction, StreamingQueryCommand, 
StreamingQueryCommandResult, StreamingQueryInstanceId, 
StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, 
WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, 
CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, 
StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, 
StreamingQueryInstanceId, StreamingQueryManagerCommand, 
StreamingQueryManagerCommandResult, WriteStreamOperationStart, 
WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import 
org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -66,7 +65,7 @@ import org.apache.spark.sql.connect.ml.MLHandler
 import org.apache.spark.sql.connect.pipelines.PipelinesHandler
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, 
SparkConnectService}
-import org.apache.spark.sql.connect.utils.MetricGenerator
+import org.apache.spark.sql.connect.utils.{MetricGenerator, 
PipelineAnalysisContextUtils}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, 
TypedAggregateExpression}
@@ -2942,25 +2941,11 @@ class SparkConnectPlanner(
         .build())
   }
 
-  private def getExtensionList[T <: Message: ClassTag](
-      extensions: mutable.Buffer[ProtoAny]): Seq[T] = {
-    val cls = implicitly[ClassTag[T]].runtimeClass
-      .asInstanceOf[Class[_ <: Message]]
-    extensions.collect {
-      case any if any.is(cls) => any.unpack(cls).asInstanceOf[T]
-    }.toSeq
-  }
-
   private def handleSqlCommand(
       command: SqlCommand,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val userContextExtensions = 
executeHolder.request.getUserContext.getExtensionsList.asScala
-    val pipelineAnalysisContextList = {
-      getExtensionList[PipelineAnalysisContext](userContextExtensions)
-    }
-    val hasPipelineAnalysisContext = pipelineAnalysisContextList.nonEmpty
-    val insidePipelineFlowFunction = 
pipelineAnalysisContextList.exists(_.hasFlowName)
+    val userContext = executeHolder.request.getUserContext
     // To avoid explicit handling of the result on the client, we build the 
expected input
     // of the relation on the server. The client has to simply forward the 
result.
     val result = SqlCommandResult.newBuilder()
@@ -2984,13 +2969,13 @@ class SparkConnectPlanner(
     }
 
     // Block unsupported SQL commands if the request comes from Spark 
Declarative Pipelines.
-    if (hasPipelineAnalysisContext) {
+    if (PipelineAnalysisContextUtils.hasPipelineAnalysisContext(userContext)) {
       PipelinesHandler.blockUnsupportedSqlCommand(queryPlan = 
transformRelation(relation))
     }
 
     // If the spark.sql() is called inside a pipeline flow function, we don't 
need to execute
     // the SQL command and defer the actual analysis and execution to the flow 
function.
-    if (insidePipelineFlowFunction) {
+    if 
(PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) {
       result.setRelation(relation)
       executeHolder.eventsManager.postFinished()
       responseObserver.onNext(
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index cdf7013211f7..ec8d95271c76 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,13 +23,13 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.utils.PlanCompressionUtils
+import org.apache.spark.sql.connect.utils.{PipelineAnalysisContextUtils, 
PlanCompressionUtils}
 import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, 
CostMode, ExtendedMode, FormattedMode, SimpleMode}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
@@ -62,6 +62,12 @@ private[connect] class SparkConnectAnalyzeHandler(
     lazy val planner = new SparkConnectPlanner(sessionHolder)
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()
+    val userContext = request.getUserContext
+
+    // Pipeline has not yet supported eager analysis inside flow function.
+    if 
(PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) {
+      throw new 
AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
+    }
 
     def transformRelation(rel: proto.Relation) = 
planner.transformRelation(rel, cachePlan = true)
     def transformRelationPlan(plan: proto.Plan) = {
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
new file mode 100644
index 000000000000..75630c45c8c5
--- /dev/null
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PipelineAnalysisContextUtils.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.utils
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
+
+import com.google.protobuf.{Any => ProtoAny, Message}
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{PipelineAnalysisContext, UserContext}
+
+/**
+ * Helpers for reading PipelineAnalysisContext extensions from a request's UserContext.
+ */
+object PipelineAnalysisContextUtils {
+
+  /** Unpack every extension whose packed type is T; extensions of other types are skipped. */
+  private def getExtensionList[T <: Message: ClassTag](
+      extensions: mutable.Buffer[ProtoAny]): Seq[T] = {
+    val cls = implicitly[ClassTag[T]].runtimeClass
+      .asInstanceOf[Class[_ <: Message]]
+    extensions.collect {
+      case any if any.is(cls) => any.unpack(cls).asInstanceOf[T]
+    }.toSeq
+  }
+
+  /** Unpack all PipelineAnalysisContext extensions attached to the given user context. */
+  private def getPipelineAnalysisContextList(
+      userContext: UserContext): Seq[PipelineAnalysisContext] = {
+    val userContextExtensions = userContext.getExtensionsList.asScala
+    getExtensionList[PipelineAnalysisContext](userContextExtensions)
+  }
+
+  /** Return whether the user context carries any PipelineAnalysisContext extension. */
+  def hasPipelineAnalysisContext(userContext: UserContext): Boolean = {
+    getPipelineAnalysisContextList(userContext).nonEmpty
+  }
+
+  /** Return whether the request was made inside a flow function (a flow name is set). */
+  def isInsidePipelineFlowFunction(userContext: UserContext): Boolean = {
+    getPipelineAnalysisContextList(userContext).exists(_.hasFlowName)
+  }
+
+  /**
+   * Return whether `plan` is an eager execution issued inside a pipeline flow function that is
+   * not allowed there: ROOT plans, commands other than SQL_COMMAND, and any other plan type.
+   */
+  def isUnsupportedEagerExecutionInsideFlowFunction(
+      userContext: UserContext,
+      plan: proto.Plan): Boolean = {
+    // Requests made outside a pipeline flow function are never flagged by this check;
+    // only flow-function requests go through the plan-type allowlist below.
+    if (!isInsidePipelineFlowFunction(userContext)) {
+      return false
+    }
+
+    plan.getOpTypeCase match {
+      // Root plans (e.g., df.collect(), df.first()) are not allowed inside a flow function.
+      case proto.Plan.OpTypeCase.ROOT => true
+      case proto.Plan.OpTypeCase.COMMAND =>
+        // Only spark.sql() (SQL_COMMAND) is allowed inside a flow function.
+        val commandAllowList = Set(proto.Command.CommandTypeCase.SQL_COMMAND)
+        !commandAllowList.contains(plan.getCommand.getCommandTypeCase)
+      // Block any other plan type by default, including ones added to the proto in future.
+      case _ => true
+    }
+  }
+}
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 98b33c3296fa..fd05b0cc357e 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -1105,6 +1105,70 @@ class PythonPipelineSuite
         |""".stripMargin)
   }
 
+  private val eagerExecutionPythonCommands = Seq(
+    "df.collect()",
+    "df.first()",
+    "df.head(0)",
+    "df.toPandas()",
+    "spark.readStream.format(\"rate\").load().writeStream" +
+      ".format(\"memory\").queryName(\"test_query_name\").start()")
+
+  gridTest("unsupported eager execution inside flow function is blocked")(
+    eagerExecutionPythonCommands) { unsupportedEagerExecutionCommand =>
+    val ex = intercept[RuntimeException] {
+      buildGraph(s"""
+        |@dp.materialized_view()
+        |def mv():
+        |  df = spark.range(5)
+        |  $unsupportedEagerExecutionCommand
+        |  return df
+        |""".stripMargin)
+    }
+    
assert(ex.getMessage.contains("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"))
+  }
+
+  gridTest("eager execution outside flow function is 
allowed")(eagerExecutionPythonCommands) {
+    unsupportedEagerExecutionCommand =>
+      buildGraph(s"""
+      |df = spark.range(5)
+      |$unsupportedEagerExecutionCommand
+      |
+      |@dp.materialized_view()
+      |def mv():
+      |  df = spark.range(5)
+      |  return df
+      |""".stripMargin)
+  }
+
+  private val eagerAnalysisPythonCommands = Seq("df.schema", "df.isStreaming", 
"df.isLocal()")
+
+  gridTest("eager analysis inside flow function is 
blocked")(eagerAnalysisPythonCommands) {
+    eagerAnalysisPythonCommand =>
+      val ex = intercept[RuntimeException] {
+        buildGraph(s"""
+          |@dp.materialized_view()
+          |def mv():
+          |  df = spark.range(5)
+          |  $eagerAnalysisPythonCommand
+          |  return df
+          |""".stripMargin)
+      }
+      
assert(ex.getMessage.contains("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"))
+  }
+
+  gridTest("eager analysis outside flow function is 
allowed")(eagerAnalysisPythonCommands) {
+    eagerAnalysisPythonCommand =>
+      buildGraph(s"""
+        |df = spark.range(5)
+        |$eagerAnalysisPythonCommand
+        |
+        |@dp.materialized_view()
+        |def mv():
+        |  df = spark.range(5)
+        |  return df
+        |""".stripMargin)
+  }
+
   override protected def test(testName: String, testTags: Tag*)(testFun: => 
Any)(implicit
       pos: Position): Unit = {
     if (PythonTestDepsChecker.isConnectDepsAvailable) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to