This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 16ff24e6efdd [SPARK-56572][SDP] Inject Spark session into Python files
16ff24e6efdd is described below
commit 16ff24e6efdd3f1786d0f775ca42f30b48a7e20b
Author: Andreas Neumann <[email protected]>
AuthorDate: Wed May 20 15:40:56 2026 -0700
[SPARK-56572][SDP] Inject Spark session into Python files
### What changes were proposed in this pull request?
In Declarative Pipelines, all Python scripts run as separate modules but
have to share the same Spark session. This session is created by the framework,
and every script needs to bring it into its own scope by declaring
```
spark = SparkSession.active()
```
This bears the risk that users might create their own Spark session
instead, and that could break dependencies between different Python scripts of
the same pipeline. It would be better to inject this session directly into the
module, so user code does not need to worry about obtaining it.
Also, change `spark-pipelines init` to omit that line from the generated
sample code.
### Why are the changes needed?
It improves the experience, and avoids hard-to-debug errors if users
initialize the spark session in a different way.
### Does this PR introduce _any_ user-facing change?
- It allows users to refer to the Spark session in their Python scripts
without first assigning it.
- Previously, users had to assign it explicitly.
- This will not break existing scripts that assign it explicitly.
### How was this patch tested?
- Existing tests in org/apache/spark/sql/pipelines/utils/APITest.scala are
modified to _not_ assign `spark`.
- New tests are added to the same suite to ensure that an explicit
`spark = SparkSession.active()` assignment and a bare `spark: SparkSession`
annotation both continue to work.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Sonnet 4.6
Closes #55493 from anew/inject-spark-session.
Lead-authored-by: Andreas Neumann <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 441fccc853042d4bc610c25ff7616e4753856940)
Signed-off-by: Gengliang Wang <[email protected]>
---
docs/declarative-pipelines-programming-guide.md | 27 +++++++
python/pyspark/pipelines/cli.py | 1 +
python/pyspark/pipelines/init_cli.py | 4 +-
.../apache/spark/sql/pipelines/utils/APITest.scala | 86 ++++++++++++++--------
4 files changed, 85 insertions(+), 33 deletions(-)
diff --git a/docs/declarative-pipelines-programming-guide.md
b/docs/declarative-pipelines-programming-guide.md
index c5d18a7cb71b..e1c2c078212a 100644
--- a/docs/declarative-pipelines-programming-guide.md
+++ b/docs/declarative-pipelines-programming-guide.md
@@ -180,6 +180,33 @@ Your pipelines implemented with the Python API must import
this module. It's rec
from pyspark import pipelines as dp
```
+### The Spark Session in Python Pipelines
+
+In Spark 4.1, every pipeline file had to declare `spark =
SparkSession.active()` explicitly. Starting in Spark 4.2, the framework injects
`spark` into each pipeline file's module namespace, so the explicit assignment
is no longer required.
+
+```python
+from pyspark import pipelines as dp
+
[email protected]_view
+def my_view():
+ return spark.range(10)
+```
+
+Pipeline files that still include `spark = SparkSession.active()` continue to
work correctly. However, if you do assign the session explicitly,
`SparkSession.active()` is the only supported way to do so. For example,
`SparkSession.builder.config(...).getOrCreate()` mutates session config, which
is blocked in SDP.
+
+Note that without the explicit assignment, many tools and editors may consider
`spark` an undefined name. To address that, you can add `spark: SparkSession`
at module scope. SDP will still inject the actual session before the module
runs, so this only documents the type for static analysis.
+
+```python
+from pyspark import pipelines as dp
+from pyspark.sql import SparkSession
+
+spark: SparkSession
+
[email protected]_view
+def my_view():
+ return spark.range(10)
+```
+
### Creating a Materialized View in Python
The `@dp.materialized_view` decorator tells SDP to create a materialized view
based on the results of a function that performs a batch read:
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 986e9828c9a2..bc0d718f8a75 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -253,6 +253,7 @@ def register_definitions(
assert module_spec.loader is not None, (
f"Module spec has no loader for {file}"
)
+ module.__dict__["spark"] = spark
with add_pipeline_analysis_context(
spark=spark, dataflow_graph_id=dataflow_graph_id,
flow_name=None
):
diff --git a/python/pyspark/pipelines/init_cli.py
b/python/pyspark/pipelines/init_cli.py
index a1dbdfd9d558..18bbb70ed9c1 100644
--- a/python/pyspark/pipelines/init_cli.py
+++ b/python/pyspark/pipelines/init_cli.py
@@ -26,9 +26,7 @@ libraries:
"""
PYTHON_EXAMPLE = """from pyspark import pipelines as dp
-from pyspark.sql import DataFrame, SparkSession
-
-spark = SparkSession.active()
+from pyspark.sql import DataFrame
@dp.materialized_view
def example_python_materialized_view() -> DataFrame:
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
index c6b457ee04eb..f59994c9490b 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
@@ -267,9 +267,7 @@ trait APITest
name = "transformations/definition.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.append_flow(target = "c", name = "append_to_c")
|def flow():
@@ -296,9 +294,7 @@ trait APITest
name = "transformations/mv.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.materialized_view
|def src():
@@ -308,9 +304,7 @@ trait APITest
name = "transformations/st.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.materialized_view
|def a():
@@ -347,9 +341,7 @@ trait APITest
name = "transformations/definition.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.materialized_view
|def a():
@@ -374,6 +366,52 @@ trait APITest
}
/* Python Language Tests */
+ test("Python Pipeline with explicit spark assignment is backward
compatible") {
+ val pipelineSpec =
+ TestPipelineSpec(include = Seq("transformations/**"))
+ val pipelineConfig = TestPipelineConfiguration(pipelineSpec)
+ val sources = Seq(
+ PipelineSourceFile(
+ name = "transformations/definition.py",
+ contents = """
+ |from pyspark import pipelines as dp
+ |from pyspark.sql import SparkSession
+ |
+ |spark = SparkSession.active()
+ |
+ |@dp.materialized_view
+ |def mv():
+ | return spark.range(5)
+ |""".stripMargin))
+ val pipeline = createAndRunPipeline(pipelineConfig, sources)
+ awaitPipelineTermination(pipeline)
+
+ checkAnswer(spark.sql(s"SELECT * FROM mv"), Seq(Row(0), Row(1), Row(2),
Row(3), Row(4)))
+ }
+
+ test("Python Pipeline with spark session placeholder works as expected") {
+ val pipelineSpec =
+ TestPipelineSpec(include = Seq("transformations/**"))
+ val pipelineConfig = TestPipelineConfiguration(pipelineSpec)
+ val sources = Seq(
+ PipelineSourceFile(
+ name = "transformations/definition.py",
+ contents = """
+ |from pyspark import pipelines as dp
+ |from pyspark.sql import SparkSession
+ |
+ |spark: SparkSession
+ |
+ |@dp.materialized_view
+ |def mv():
+ | return spark.range(5)
+ |""".stripMargin))
+ val pipeline = createAndRunPipeline(pipelineConfig, sources)
+ awaitPipelineTermination(pipeline)
+
+ checkAnswer(spark.sql(s"SELECT * FROM mv"), Seq(Row(0), Row(1), Row(2),
Row(3), Row(4)))
+ }
+
test("Python Pipeline with materialized_view, create_streaming_table, and
append_flow") {
val pipelineSpec =
TestPipelineSpec(include = Seq("transformations/**"))
@@ -383,9 +421,7 @@ trait APITest
name = "transformations/st.py",
contents = s"""
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|dp.create_streaming_table(
| name = "a",
@@ -401,9 +437,7 @@ trait APITest
name = "transformations/mv.py",
contents = s"""
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.materialized_view(
| name = "src",
@@ -431,9 +465,7 @@ trait APITest
name = "transformations/definition.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|@dp.temporary_view(
| name = "view_1",
@@ -475,9 +507,7 @@ trait APITest
contents =
s"""
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
- |
- |spark = SparkSession.active()
+ |from pyspark.sql import DataFrame
|
|dp.create_sink(
| "mySink",
@@ -518,11 +548,9 @@ trait APITest
name = "transformations/definition.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
+ |from pyspark.sql import DataFrame
|from pyspark.sql.functions import col
|
- |spark = SparkSession.active()
- |
|@dp.materialized_view(partition_cols = ["id_mod"])
|def mv():
| return spark.range(5).withColumn("id_mod", col("id") %
2)
@@ -551,11 +579,9 @@ trait APITest
name = "transformations/definition.py",
contents = """
|from pyspark import pipelines as dp
- |from pyspark.sql import DataFrame, SparkSession
+ |from pyspark.sql import DataFrame
|from pyspark.sql.functions import col
|
- |spark = SparkSession.active()
- |
|@dp.materialized_view(cluster_by = ["cluster_col1"])
|def mv():
| df = spark.range(10)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]