This is an automated email from the ASF dual-hosted git repository.

sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dc687d4c83b8 [SPARK-52853][SDP] Prevent imperative PySpark methods in declarative pipelines
dc687d4c83b8 is described below

commit dc687d4c83b877e90c8dc03fb88f13440d4ae911
Author: Jacky Wang <jacky.w...@databricks.com>
AuthorDate: Thu Jul 24 11:17:00 2025 -0700

    [SPARK-52853][SDP] Prevent imperative PySpark methods in declarative pipelines
    
    ### What changes were proposed in this pull request?
    
    This PR adds a context manager `block_session_mutations()` that prevents the execution of imperative Spark operations within declarative pipeline definitions. When one of the blocked methods is called, the user receives a clear error message with guidance on the declarative alternative (a minimal usage sketch follows the list of blocked methods below).
    
    #### Blocked Methods
    
    ##### Configuration Management
    - **`spark.conf.set()`** → Use pipeline spec or `spark_conf` decorator parameter
    
    ##### Catalog Management
    - **`spark.catalog.setCurrentCatalog()`** → Set via pipeline spec or dataset decorator `name` argument
    - **`spark.catalog.setCurrentDatabase()`** → Set via pipeline spec or dataset decorator `name` argument
    
    ##### Temporary View Management
    - **`spark.catalog.dropTempView()`** → Remove temporary view definition directly
    - **`spark.catalog.dropGlobalTempView()`** → Remove temporary view definition directly
    - **`DataFrame.createTempView()`** → Use `temporary_view` decorator
    - **`DataFrame.createOrReplaceTempView()`** → Use `temporary_view` decorator
    - **`DataFrame.createGlobalTempView()`** → Use `temporary_view` decorator
    - **`DataFrame.createOrReplaceGlobalTempView()`** → Use `temporary_view` decorator
    
    ##### UDF Registration
    - **`spark.udf.register()`** → Define and register UDFs before pipeline execution
    - **`spark.udf.registerJavaFunction()`** → Define and register Java UDFs before pipeline execution
    - **`spark.udf.registerJavaUDAF()`** → Define and register Java UDAFs before pipeline execution
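    
    As a minimal sketch of the resulting behavior (not part of this patch; it assumes `spark` is an active Spark Connect session in scope in a pipeline definition file):
    
    ```python
    from pyspark.errors import PySparkException
    from pyspark.pipelines.block_session_mutations import block_session_mutations

    # The spark-pipelines CLI wraps execution of each Python definition file in
    # this context manager; inside it, imperative session mutations raise instead
    # of taking effect. `spark` is assumed to be an existing Spark Connect session.
    with block_session_mutations():
        try:
            spark.conf.set("spark.sql.shuffle.partitions", "10")  # blocked
        except PySparkException as e:
            # Error condition: SESSION_MUTATION_IN_DECLARATIVE_PIPELINE.SET_RUNTIME_CONF
            print(e.getCondition())
    ```
    
    The original methods are restored when the context manager exits, so code outside pipeline definition files is unaffected.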
    
    ### Why are the changes needed?
    
    These are imperative constructs that can cause friction and unexpected behavior when used from within a pipeline declaration. For example, they make pipeline behavior sensitive to the order in which Python files are imported, which can be unpredictable. Mechanisms already exist for setting Spark confs for pipelines, such as the pipeline spec and the `spark_conf` decorator parameter.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Imperative session mutations, such as setting Spark confs, are now blocked inside pipeline definition files.
    
    ### How was this patch tested?
    
    Created a new test suite to verify that the context manager behaves as expected, and ran the `spark-pipelines` CLI manually.
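    
    For reference, a hypothetical way to invoke the new test module directly (assuming a local PySpark dev environment with Spark Connect available); the normal CI path is the `python_test_goals` registration in `dev/sparktestsupport/modules.py` shown in the diff below:
    
    ```python
    # Hypothetical direct invocation of the test module added by this PR.
    import unittest

    from pyspark.pipelines.tests import test_block_session_mutations

    unittest.main(module=test_block_session_mutations, exit=False, verbosity=2)
    ```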
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51590 from JiaqiWang18/SPARK-52853-prevent-py-conf-set.
    
    Authored-by: Jacky Wang <jacky.w...@databricks.com>
    Signed-off-by: Sandy Ryza <sandy.r...@databricks.com>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/errors/error-conditions.json        |  67 ++++++
 .../pyspark/pipelines/block_session_mutations.py   | 135 +++++++++++
 python/pyspark/pipelines/cli.py                    |   4 +-
 .../tests/test_block_session_mutations.py          | 259 +++++++++++++++++++++
 5 files changed, 465 insertions(+), 1 deletion(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 9822080693f8..2e4f67b78544 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1520,6 +1520,7 @@ pyspark_pipelines = Module(
     source_file_regexes=["python/pyspark/pipelines"],
     python_test_goals=[
         "pyspark.pipelines.tests.test_block_connect_access",
+        "pyspark.pipelines.tests.test_block_session_mutations",
         "pyspark.pipelines.tests.test_cli",
         "pyspark.pipelines.tests.test_decorators",
         "pyspark.pipelines.tests.test_graph_element_registry",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index ffb1afbd03bd..bf54801a0189 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1007,6 +1007,73 @@
       "Cannot start a remote Spark session because there is a regular Spark session already running."
     ]
   },
+  "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE": {
+    "message": [
+      "Session mutation <method> is not allowed in declarative pipelines."
+    ],
+    "sub_class": {
+      "SET_RUNTIME_CONF": {
+        "message": [
+          "Instead set configuration via the pipeline spec or use the 'spark_conf' argument in various decorators."
+        ]
+      },
+      "SET_CURRENT_CATALOG": {
+        "message": [
+          "Instead set catalog via the pipeline spec or the 'name' argument on the dataset decorators."
+        ]
+      },
+      "SET_CURRENT_DATABASE": {
+        "message": [
+          "Instead set database via the pipeline spec or the 'name' argument on the dataset decorators."
+        ]
+      },
+      "DROP_TEMP_VIEW": {
+        "message": [
+          "Instead remove the temporary view definition directly."
+        ]
+      },
+      "DROP_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead remove the temporary view definition directly."
+        ]
+      },
+      "CREATE_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_OR_REPLACE_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "REGISTER_UDF": {
+        "message": [
+          "Instead define and register UDFs before pipeline execution."
+        ]
+      },
+      "REGISTER_JAVA_UDF": {
+        "message": [
+          "Instead define and register Java UDFs before pipeline execution."
+        ]
+      },
+      "REGISTER_JAVA_UDAF": {
+        "message": [
+          "Instead define and register Java UDAFs before pipeline execution."
+        ]
+      }
+    }
+  },
   "SESSION_NEED_CONN_STR_OR_BUILDER": {
     "message": [
       "Needs either connection string or channelBuilder (mutually exclusive) to create a new SparkSession."
diff --git a/python/pyspark/pipelines/block_session_mutations.py b/python/pyspark/pipelines/block_session_mutations.py
new file mode 100644
index 000000000000..df63d2023a4b
--- /dev/null
+++ b/python/pyspark/pipelines/block_session_mutations.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, NoReturn, List, Callable
+
+from pyspark.errors import PySparkException
+from pyspark.sql.connect.catalog import Catalog
+from pyspark.sql.connect.conf import RuntimeConf
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.udf import UDFRegistration
+
+# pyspark methods that should be blocked from executing in python pipeline definition files
+ERROR_CLASS = "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE"
+BLOCKED_METHODS: List = [
+    {
+        "class": RuntimeConf,
+        "method": "set",
+        "error_sub_class": "SET_RUNTIME_CONF",
+    },
+    {
+        "class": Catalog,
+        "method": "setCurrentCatalog",
+        "error_sub_class": "SET_CURRENT_CATALOG",
+    },
+    {
+        "class": Catalog,
+        "method": "setCurrentDatabase",
+        "error_sub_class": "SET_CURRENT_DATABASE",
+    },
+    {
+        "class": Catalog,
+        "method": "dropTempView",
+        "error_sub_class": "DROP_TEMP_VIEW",
+    },
+    {
+        "class": Catalog,
+        "method": "dropGlobalTempView",
+        "error_sub_class": "DROP_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createTempView",
+        "error_sub_class": "CREATE_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createOrReplaceTempView",
+        "error_sub_class": "CREATE_OR_REPLACE_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createGlobalTempView",
+        "error_sub_class": "CREATE_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createOrReplaceGlobalTempView",
+        "error_sub_class": "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "register",
+        "error_sub_class": "REGISTER_UDF",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "registerJavaFunction",
+        "error_sub_class": "REGISTER_JAVA_UDF",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "registerJavaUDAF",
+        "error_sub_class": "REGISTER_JAVA_UDAF",
+    },
+]
+
+
+def _create_blocked_method(error_method_name: str, error_sub_class: str) -> Callable:
+    def blocked_method(*args: object, **kwargs: object) -> NoReturn:
+        raise PySparkException(
+            errorClass=f"{ERROR_CLASS}.{error_sub_class}",
+            messageParameters={
+                "method": error_method_name,
+            },
+        )
+
+    return blocked_method
+
+
+@contextmanager
+def block_session_mutations() -> Generator[None, None, None]:
+    """
+    Context manager that blocks imperative constructs found in a pipeline python definition file
+    See BLOCKED_METHODS above for a list
+    """
+    # Store original methods
+    original_methods = {}
+    for method_info in BLOCKED_METHODS:
+        cls = method_info["class"]
+        method_name = method_info["method"]
+        original_methods[(cls, method_name)] = getattr(cls, method_name)
+
+    try:
+        # Replace methods with blocked versions
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            error_method_name = f"'{cls.__name__}.{method_name}'"
+            blocked_method = _create_blocked_method(
+                error_method_name, method_info["error_sub_class"]
+            )
+            setattr(cls, method_name, blocked_method)
+
+        yield
+    finally:
+        # Restore original methods
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            original_method = original_methods[(cls, method_name)]
+            setattr(cls, method_name, original_method)
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index cbcac35cf1b3..43f9ae150f3f 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -32,6 +32,7 @@ from typing import Any, Generator, List, Mapping, Optional, Sequence
 
 from pyspark.errors import PySparkException, PySparkTypeError
 from pyspark.sql import SparkSession
+from pyspark.pipelines.block_session_mutations import block_session_mutations
 from pyspark.pipelines.graph_element_registry import (
     graph_element_registration_context,
     GraphElementRegistry,
@@ -192,7 +193,8 @@ def register_definitions(
                         assert (
                             module_spec.loader is not None
                         ), f"Module spec has no loader for {file}"
-                        module_spec.loader.exec_module(module)
+                        with block_session_mutations():
+                            module_spec.loader.exec_module(module)
                     elif file.suffix == ".sql":
                         log_with_curr_timestamp(f"Registering SQL file {file}...")
                         with file.open("r") as f:
diff --git a/python/pyspark/pipelines/tests/test_block_session_mutations.py b/python/pyspark/pipelines/tests/test_block_session_mutations.py
new file mode 100644
index 000000000000..771321d73832
--- /dev/null
+++ b/python/pyspark/pipelines/tests/test_block_session_mutations.py
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.errors import PySparkException
+from pyspark.sql.types import StringType
+from pyspark.testing.connectutils import (
+    ReusedConnectTestCase,
+    should_test_connect,
+    connect_requirement_message,
+)
+
+from pyspark.pipelines.block_session_mutations import (
+    block_session_mutations,
+    BLOCKED_METHODS,
+    ERROR_CLASS,
+)
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message or "Connect not available")
+class BlockImperativeConfSetConnectTests(ReusedConnectTestCase):
+    def test_blocks_runtime_conf_set(self):
+        """Test that spark.conf.set() is blocked."""
+        config = self.spark.conf
+
+        test_cases = [
+            ("spark.test.string", "string_value"),
+            ("spark.test.int", 42),
+            ("spark.test.bool", True),
+        ]
+
+        for key, value in test_cases:
+            with self.subTest(key=key, value=value):
+                with block_session_mutations():
+                    with self.assertRaises(PySparkException) as context:
+                        config.set(key, value)
+
+                    self.assertEqual(
+                        context.exception.getCondition(),
+                        f"{ERROR_CLASS}.SET_RUNTIME_CONF",
+                    )
+                    self.assertIn("'RuntimeConf.set'", str(context.exception))
+
+    def test_blocks_catalog_set_current_catalog(self):
+        """Test that spark.catalog.setCurrentCatalog() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.setCurrentCatalog("test_catalog")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.SET_CURRENT_CATALOG",
+            )
+            self.assertIn("'Catalog.setCurrentCatalog'", str(context.exception))
+
+    def test_blocks_catalog_set_current_database(self):
+        """Test that spark.catalog.setCurrentDatabase() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.setCurrentDatabase("test_db")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.SET_CURRENT_DATABASE",
+            )
+            self.assertIn("'Catalog.setCurrentDatabase'", str(context.exception))
+
+    def test_blocks_catalog_drop_temp_view(self):
+        """Test that spark.catalog.dropTempView() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.dropTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.DROP_TEMP_VIEW",
+            )
+            self.assertIn("'Catalog.dropTempView'", str(context.exception))
+
+    def test_blocks_catalog_drop_global_temp_view(self):
+        """Test that spark.catalog.dropGlobalTempView() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.dropGlobalTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.DROP_GLOBAL_TEMP_VIEW",
+            )
+            self.assertIn("'Catalog.dropGlobalTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_temp_view(self):
+        """Test that DataFrame.createTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.CREATE_TEMP_VIEW",
+            )
+            self.assertIn("'DataFrame.createTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_or_replace_temp_view(self):
+        """Test that DataFrame.createOrReplaceTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createOrReplaceTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.CREATE_OR_REPLACE_TEMP_VIEW",
+            )
+            self.assertIn("'DataFrame.createOrReplaceTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_global_temp_view(self):
+        """Test that DataFrame.createGlobalTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createGlobalTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.CREATE_GLOBAL_TEMP_VIEW",
+            )
+            self.assertIn("'DataFrame.createGlobalTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_or_replace_global_temp_view(self):
+        """Test that DataFrame.createOrReplaceGlobalTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createOrReplaceGlobalTempView("test_view")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW",
+            )
+            self.assertIn("'DataFrame.createOrReplaceGlobalTempView'", str(context.exception))
+
+    def test_blocks_udf_register(self):
+        """Test that spark.udf.register() is blocked."""
+        udf_registry = self.spark.udf
+
+        def test_func(x):
+            return x + 1
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.register("test_udf", test_func, StringType())
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.REGISTER_UDF",
+            )
+            self.assertIn("'UDFRegistration.register'", str(context.exception))
+
+    def test_blocks_udf_register_java_function(self):
+        """Test that spark.udf.registerJavaFunction() is blocked."""
+        udf_registry = self.spark.udf
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.registerJavaFunction(
+                    "test_java_udf", "com.example.TestUDF", StringType()
+                )
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.REGISTER_JAVA_UDF",
+            )
+            self.assertIn("'UDFRegistration.registerJavaFunction'", str(context.exception))
+
+    def test_blocks_udf_register_java_udaf(self):
+        """Test that spark.udf.registerJavaUDAF() is blocked."""
+        udf_registry = self.spark.udf
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.registerJavaUDAF("test_java_udaf", "com.example.TestUDAF")
+
+            self.assertEqual(
+                context.exception.getCondition(),
+                f"{ERROR_CLASS}.REGISTER_JAVA_UDAF",
+            )
+            self.assertIn("'UDFRegistration.registerJavaUDAF'", str(context.exception))
+
+    def test_restores_original_methods_after_context(self):
+        """Test that all methods are properly restored after context manager exits."""
+        # Store original methods
+        original_methods = {}
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            original_methods[(cls, method_name)] = getattr(cls, method_name)
+
+        # Verify methods are originally set correctly
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)])
+
+        # Verify methods are replaced during context
+        with block_session_mutations():
+            for method_info in BLOCKED_METHODS:
+                cls = method_info["class"]
+                method_name = method_info["method"]
+                with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                    self.assertIsNot(
+                        getattr(cls, method_name), original_methods[(cls, method_name)]
+                    )
+
+        # Verify methods are restored after context
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)])
+
+
+if __name__ == "__main__":
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
