This is an automated email from the ASF dual-hosted git repository.

sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new dc687d4c83b8 [SPARK-52853][SDP] Prevent imperative PySpark methods in declarative pipelines
dc687d4c83b8 is described below

commit dc687d4c83b877e90c8dc03fb88f13440d4ae911
Author: Jacky Wang <jacky.w...@databricks.com>
AuthorDate: Thu Jul 24 11:17:00 2025 -0700

[SPARK-52853][SDP] Prevent imperative PySpark methods in declarative pipelines

### What changes were proposed in this pull request?

This PR adds a context manager `block_session_mutations()` that prevents the execution of imperative Spark operations within declarative pipeline definitions. When these blocked methods are called, users receive clear error messages with guidance on declarative alternatives.

#### Blocked Methods

##### Configuration Management
- **`spark.conf.set()`** → Use pipeline spec or `spark_conf` decorator parameter

##### Catalog Management
- **`spark.catalog.setCurrentCatalog()`** → Set via pipeline spec or dataset decorator `name` argument
- **`spark.catalog.setCurrentDatabase()`** → Set via pipeline spec or dataset decorator `name` argument

##### Temporary View Management
- **`spark.catalog.dropTempView()`** → Remove temporary view definition directly
- **`spark.catalog.dropGlobalTempView()`** → Remove temporary view definition directly
- **`DataFrame.createTempView()`** → Use `temporary_view` decorator
- **`DataFrame.createOrReplaceTempView()`** → Use `temporary_view` decorator
- **`DataFrame.createGlobalTempView()`** → Use `temporary_view` decorator
- **`DataFrame.createOrReplaceGlobalTempView()`** → Use `temporary_view` decorator

##### UDF Registration
- **`spark.udf.register()`** → Define and register UDFs before pipeline execution
- **`spark.udf.registerJavaFunction()`** → Define and register Java UDFs before pipeline execution
- **`spark.udf.registerJavaUDAF()`** → Define and register Java UDAFs before pipeline execution

### Why are the changes needed?

These are imperative constructs that can cause friction and unexpected behavior when used inside a pipeline declaration. For example, they make pipeline behavior sensitive to the order in which Python files are imported, which can be unpredictable. Mechanisms for setting Spark confs for pipelines already exist.

### Does this PR introduce _any_ user-facing change?

Yes, it prevents Spark confs from being set imperatively in pipeline definition files.

### How was this patch tested?

Created a new test suite to verify that the context manager behaves as expected, and ran the `spark-pipelines` CLI manually.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51590 from JiaqiWang18/SPARK-52853-prevent-py-conf-set.
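For illustration, a minimal sketch of how the new guard behaves and what the declarative alternatives look like in a pipeline definition file. The module path, context manager name, and error condition below come from this patch; the `pyspark.pipelines` decorator import alias and exact decorator signatures are assumptions (only the `temporary_view` decorator and the `spark_conf`/`name` parameters are referenced above).

```python
from pyspark.sql import SparkSession
from pyspark.errors import PySparkException
from pyspark.pipelines.block_session_mutations import block_session_mutations

# A Spark Connect session; the guard patches the Connect-side classes
# (RuntimeConf, Catalog, DataFrame, UDFRegistration).
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# The pipelines CLI runs each Python definition file under this guard, so an
# imperative session mutation fails fast with a descriptive error condition:
with block_session_mutations():
    try:
        spark.conf.set("spark.sql.shuffle.partitions", "10")
    except PySparkException as e:
        print(e.getCondition())  # SESSION_MUTATION_IN_DECLARATIVE_PIPELINE.SET_RUNTIME_CONF

# Hypothetical declarative equivalents (import alias and decorator signatures
# are assumptions, not taken verbatim from this patch):
from pyspark import pipelines as dp

@dp.temporary_view(name="events_view")  # instead of df.createOrReplaceTempView("events_view")
def events_view():
    return spark.read.table("raw.events")

@dp.table(spark_conf={"spark.sql.shuffle.partitions": "10"})  # instead of spark.conf.set(...)
def events_summary():
    return spark.read.table("events_view").groupBy("event_type").count()
```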
Authored-by: Jacky Wang <jacky.w...@databricks.com> Signed-off-by: Sandy Ryza <sandy.r...@databricks.com> --- dev/sparktestsupport/modules.py | 1 + python/pyspark/errors/error-conditions.json | 67 ++++++ .../pyspark/pipelines/block_session_mutations.py | 135 +++++++++++ python/pyspark/pipelines/cli.py | 4 +- .../tests/test_block_session_mutations.py | 259 +++++++++++++++++++++ 5 files changed, 465 insertions(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9822080693f8..2e4f67b78544 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1520,6 +1520,7 @@ pyspark_pipelines = Module( source_file_regexes=["python/pyspark/pipelines"], python_test_goals=[ "pyspark.pipelines.tests.test_block_connect_access", + "pyspark.pipelines.tests.test_block_session_mutations", "pyspark.pipelines.tests.test_cli", "pyspark.pipelines.tests.test_decorators", "pyspark.pipelines.tests.test_graph_element_registry", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ffb1afbd03bd..bf54801a0189 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1007,6 +1007,73 @@ "Cannot start a remote Spark session because there is a regular Spark session already running." ] }, + "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE": { + "message": [ + "Session mutation <method> is not allowed in declarative pipelines." + ], + "sub_class": { + "SET_RUNTIME_CONF": { + "message": [ + "Instead set configuration via the pipeline spec or use the 'spark_conf' argument in various decorators." + ] + }, + "SET_CURRENT_CATALOG": { + "message": [ + "Instead set catalog via the pipeline spec or the 'name' argument on the dataset decorators." + ] + }, + "SET_CURRENT_DATABASE": { + "message": [ + "Instead set database via the pipeline spec or the 'name' argument on the dataset decorators." + ] + }, + "DROP_TEMP_VIEW": { + "message": [ + "Instead remove the temporary view definition directly." + ] + }, + "DROP_GLOBAL_TEMP_VIEW": { + "message": [ + "Instead remove the temporary view definition directly." + ] + }, + "CREATE_TEMP_VIEW": { + "message": [ + "Instead use the @temporary_view decorator to define temporary views." + ] + }, + "CREATE_OR_REPLACE_TEMP_VIEW": { + "message": [ + "Instead use the @temporary_view decorator to define temporary views." + ] + }, + "CREATE_GLOBAL_TEMP_VIEW": { + "message": [ + "Instead use the @temporary_view decorator to define temporary views." + ] + }, + "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW": { + "message": [ + "Instead use the @temporary_view decorator to define temporary views." + ] + }, + "REGISTER_UDF": { + "message": [ + "" + ] + }, + "REGISTER_JAVA_UDF": { + "message": [ + "" + ] + }, + "REGISTER_JAVA_UDAF": { + "message": [ + "" + ] + } + } + }, "SESSION_NEED_CONN_STR_OR_BUILDER": { "message": [ "Needs either connection string or channelBuilder (mutually exclusive) to create a new SparkSession." diff --git a/python/pyspark/pipelines/block_session_mutations.py b/python/pyspark/pipelines/block_session_mutations.py new file mode 100644 index 000000000000..df63d2023a4b --- /dev/null +++ b/python/pyspark/pipelines/block_session_mutations.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from contextlib import contextmanager +from typing import Generator, NoReturn, List, Callable + +from pyspark.errors import PySparkException +from pyspark.sql.connect.catalog import Catalog +from pyspark.sql.connect.conf import RuntimeConf +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.udf import UDFRegistration + +# pyspark methods that should be blocked from executing in python pipeline definition files +ERROR_CLASS = "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE" +BLOCKED_METHODS: List = [ + { + "class": RuntimeConf, + "method": "set", + "error_sub_class": "SET_RUNTIME_CONF", + }, + { + "class": Catalog, + "method": "setCurrentCatalog", + "error_sub_class": "SET_CURRENT_CATALOG", + }, + { + "class": Catalog, + "method": "setCurrentDatabase", + "error_sub_class": "SET_CURRENT_DATABASE", + }, + { + "class": Catalog, + "method": "dropTempView", + "error_sub_class": "DROP_TEMP_VIEW", + }, + { + "class": Catalog, + "method": "dropGlobalTempView", + "error_sub_class": "DROP_GLOBAL_TEMP_VIEW", + }, + { + "class": DataFrame, + "method": "createTempView", + "error_sub_class": "CREATE_TEMP_VIEW", + }, + { + "class": DataFrame, + "method": "createOrReplaceTempView", + "error_sub_class": "CREATE_OR_REPLACE_TEMP_VIEW", + }, + { + "class": DataFrame, + "method": "createGlobalTempView", + "error_sub_class": "CREATE_GLOBAL_TEMP_VIEW", + }, + { + "class": DataFrame, + "method": "createOrReplaceGlobalTempView", + "error_sub_class": "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW", + }, + { + "class": UDFRegistration, + "method": "register", + "error_sub_class": "REGISTER_UDF", + }, + { + "class": UDFRegistration, + "method": "registerJavaFunction", + "error_sub_class": "REGISTER_JAVA_UDF", + }, + { + "class": UDFRegistration, + "method": "registerJavaUDAF", + "error_sub_class": "REGISTER_JAVA_UDAF", + }, +] + + +def _create_blocked_method(error_method_name: str, error_sub_class: str) -> Callable: + def blocked_method(*args: object, **kwargs: object) -> NoReturn: + raise PySparkException( + errorClass=f"{ERROR_CLASS}.{error_sub_class}", + messageParameters={ + "method": error_method_name, + }, + ) + + return blocked_method + + +@contextmanager +def block_session_mutations() -> Generator[None, None, None]: + """ + Context manager that blocks imperative constructs found in a pipeline python definition file + See BLOCKED_METHODS above for a list + """ + # Store original methods + original_methods = {} + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + original_methods[(cls, method_name)] = getattr(cls, method_name) + + try: + # Replace methods with blocked versions + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + error_method_name = f"'{cls.__name__}.{method_name}'" + blocked_method = _create_blocked_method( + error_method_name, method_info["error_sub_class"] + ) + setattr(cls, method_name, blocked_method) + + 
yield + finally: + # Restore original methods + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + original_method = original_methods[(cls, method_name)] + setattr(cls, method_name, original_method) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index cbcac35cf1b3..43f9ae150f3f 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -32,6 +32,7 @@ from typing import Any, Generator, List, Mapping, Optional, Sequence from pyspark.errors import PySparkException, PySparkTypeError from pyspark.sql import SparkSession +from pyspark.pipelines.block_session_mutations import block_session_mutations from pyspark.pipelines.graph_element_registry import ( graph_element_registration_context, GraphElementRegistry, @@ -192,7 +193,8 @@ def register_definitions( assert ( module_spec.loader is not None ), f"Module spec has no loader for {file}" - module_spec.loader.exec_module(module) + with block_session_mutations(): + module_spec.loader.exec_module(module) elif file.suffix == ".sql": log_with_curr_timestamp(f"Registering SQL file {file}...") with file.open("r") as f: diff --git a/python/pyspark/pipelines/tests/test_block_session_mutations.py b/python/pyspark/pipelines/tests/test_block_session_mutations.py new file mode 100644 index 000000000000..771321d73832 --- /dev/null +++ b/python/pyspark/pipelines/tests/test_block_session_mutations.py @@ -0,0 +1,259 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.errors import PySparkException +from pyspark.sql.types import StringType +from pyspark.testing.connectutils import ( + ReusedConnectTestCase, + should_test_connect, + connect_requirement_message, +) + +from pyspark.pipelines.block_session_mutations import ( + block_session_mutations, + BLOCKED_METHODS, + ERROR_CLASS, +) + + +@unittest.skipIf(not should_test_connect, connect_requirement_message or "Connect not available") +class BlockImperativeConfSetConnectTests(ReusedConnectTestCase): + def test_blocks_runtime_conf_set(self): + """Test that spark.conf.set() is blocked.""" + config = self.spark.conf + + test_cases = [ + ("spark.test.string", "string_value"), + ("spark.test.int", 42), + ("spark.test.bool", True), + ] + + for key, value in test_cases: + with self.subTest(key=key, value=value): + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + config.set(key, value) + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.SET_RUNTIME_CONF", + ) + self.assertIn("'RuntimeConf.set'", str(context.exception)) + + def test_blocks_catalog_set_current_catalog(self): + """Test that spark.catalog.setCurrentCatalog() is blocked.""" + catalog = self.spark.catalog + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + catalog.setCurrentCatalog("test_catalog") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.SET_CURRENT_CATALOG", + ) + self.assertIn("'Catalog.setCurrentCatalog'", str(context.exception)) + + def test_blocks_catalog_set_current_database(self): + """Test that spark.catalog.setCurrentDatabase() is blocked.""" + catalog = self.spark.catalog + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + catalog.setCurrentDatabase("test_db") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.SET_CURRENT_DATABASE", + ) + self.assertIn("'Catalog.setCurrentDatabase'", str(context.exception)) + + def test_blocks_catalog_drop_temp_view(self): + """Test that spark.catalog.dropTempView() is blocked.""" + catalog = self.spark.catalog + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + catalog.dropTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.DROP_TEMP_VIEW", + ) + self.assertIn("'Catalog.dropTempView'", str(context.exception)) + + def test_blocks_catalog_drop_global_temp_view(self): + """Test that spark.catalog.dropGlobalTempView() is blocked.""" + catalog = self.spark.catalog + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + catalog.dropGlobalTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.DROP_GLOBAL_TEMP_VIEW", + ) + self.assertIn("'Catalog.dropGlobalTempView'", str(context.exception)) + + def test_blocks_dataframe_create_temp_view(self): + """Test that DataFrame.createTempView() is blocked.""" + df = self.spark.range(1) + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + df.createTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.CREATE_TEMP_VIEW", + ) + self.assertIn("'DataFrame.createTempView'", str(context.exception)) + + def test_blocks_dataframe_create_or_replace_temp_view(self): + """Test that DataFrame.createOrReplaceTempView() is blocked.""" + df = self.spark.range(1) + + with block_session_mutations(): + with 
self.assertRaises(PySparkException) as context: + df.createOrReplaceTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.CREATE_OR_REPLACE_TEMP_VIEW", + ) + self.assertIn("'DataFrame.createOrReplaceTempView'", str(context.exception)) + + def test_blocks_dataframe_create_global_temp_view(self): + """Test that DataFrame.createGlobalTempView() is blocked.""" + df = self.spark.range(1) + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + df.createGlobalTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.CREATE_GLOBAL_TEMP_VIEW", + ) + self.assertIn("'DataFrame.createGlobalTempView'", str(context.exception)) + + def test_blocks_dataframe_create_or_replace_global_temp_view(self): + """Test that DataFrame.createOrReplaceGlobalTempView() is blocked.""" + df = self.spark.range(1) + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + df.createOrReplaceGlobalTempView("test_view") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW", + ) + self.assertIn("'DataFrame.createOrReplaceGlobalTempView'", str(context.exception)) + + def test_blocks_udf_register(self): + """Test that spark.udf.register() is blocked.""" + udf_registry = self.spark.udf + + def test_func(x): + return x + 1 + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + udf_registry.register("test_udf", test_func, StringType()) + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.REGISTER_UDF", + ) + self.assertIn("'UDFRegistration.register'", str(context.exception)) + + def test_blocks_udf_register_java_function(self): + """Test that spark.udf.registerJavaFunction() is blocked.""" + udf_registry = self.spark.udf + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + udf_registry.registerJavaFunction( + "test_java_udf", "com.example.TestUDF", StringType() + ) + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.REGISTER_JAVA_UDF", + ) + self.assertIn("'UDFRegistration.registerJavaFunction'", str(context.exception)) + + def test_blocks_udf_register_java_udaf(self): + """Test that spark.udf.registerJavaUDAF() is blocked.""" + udf_registry = self.spark.udf + + with block_session_mutations(): + with self.assertRaises(PySparkException) as context: + udf_registry.registerJavaUDAF("test_java_udaf", "com.example.TestUDAF") + + self.assertEqual( + context.exception.getCondition(), + f"{ERROR_CLASS}.REGISTER_JAVA_UDAF", + ) + self.assertIn("'UDFRegistration.registerJavaUDAF'", str(context.exception)) + + def test_restores_original_methods_after_context(self): + """Test that all methods are properly restored after context manager exits.""" + # Store original methods + original_methods = {} + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + original_methods[(cls, method_name)] = getattr(cls, method_name) + + # Verify methods are originally set correctly + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + with self.subTest(class_method=f"{cls.__name__}.{method_name}"): + self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)]) + + # Verify methods are replaced during context + with block_session_mutations(): + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + 
method_name = method_info["method"] + with self.subTest(class_method=f"{cls.__name__}.{method_name}"): + self.assertIsNot( + getattr(cls, method_name), original_methods[(cls, method_name)] + ) + + # Verify methods are restored after context + for method_info in BLOCKED_METHODS: + cls = method_info["class"] + method_name = method_info["method"] + with self.subTest(class_method=f"{cls.__name__}.{method_name}"): + self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)]) + + +if __name__ == "__main__": + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org