This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fba4c8c20e52 [SPARK-48970][PYTHON][ML] Avoid using SparkSession.getActiveSession in spark ML reader/writer
fba4c8c20e52 is described below
commit fba4c8c20e523c9a441f007442efd616320e7be4
Author: Weichen Xu <[email protected]>
AuthorDate: Tue Jul 23 19:19:28 2024 +0800
[SPARK-48970][PYTHON][ML] Avoid using SparkSession.getActiveSession in spark ML reader/writer
### What changes were proposed in this pull request?
`SparkSession.getActiveSession` returns a thread-local session, but Spark ML readers/writers may run in a different thread from the one that created the session, which causes `SparkSession.getActiveSession` to return None.
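For illustration, a minimal sketch of the failure mode (the local master setting and the thread setup are assumptions for the sketch, not part of this commit):
```
import threading

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

def check_active() -> None:
    # getActiveSession() is thread-local: the session was made active on
    # the main thread, so this worker thread typically sees no session.
    print(SparkSession.getActiveSession())  # prints None in the failing case

t = threading.Thread(target=check_active)
t.start()
t.join()
```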
### Why are the changes needed?
It fixes bugs like the following:
```
spark = SparkSession.getActiveSession()
> spark.createDataFrame( # type: ignore[union-attr]
[(metadataJson,)], schema=["value"]
).coalesce(1).write.text(metadataPath)
E AttributeError: 'NoneType' object has no attribute 'createDataFrame'
```
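The patch switches both sides to session lookups that do not depend on thread-local state: `SparkSession.builder().sparkContext(sc).getOrCreate()` in Scala, and the private `SparkSession._getActiveSessionOrCreate()` helper in Python. A rough sketch of the fallback pattern the Python helper follows (`get_session_or_create` is a hypothetical name for illustration):
```
from typing import Optional

from pyspark.sql import SparkSession

def get_session_or_create() -> SparkSession:
    # Prefer the thread-local active session when present, but fall back
    # to getOrCreate() instead of dereferencing a possible None.
    spark: Optional[SparkSession] = SparkSession.getActiveSession()
    if spark is None:
        spark = SparkSession.builder.getOrCreate()
    return spark
```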
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually.
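For reference, one way to reproduce the original failure manually is to save an ML instance from a worker thread; a sketch along those lines (the Binarizer params and the output path are assumptions):
```
import threading

from pyspark.ml.feature import Binarizer
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

def save_in_thread() -> None:
    # Before this patch, DefaultParamsWriter resolved the session via
    # getActiveSession(), which returns None on this thread and raised
    # the AttributeError shown above.
    binarizer = Binarizer(threshold=0.5, inputCol="values", outputCol="features")
    binarizer.write().overwrite().save("/tmp/binarizer-demo")

t = threading.Thread(target=save_in_thread)
t.start()
t.join()
```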
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47453 from WeichenXu123/SPARK-48970.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +-
python/pyspark/ml/util.py | 12 ++++++------
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 021595f76c24..c127575e1470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -588,7 +588,7 @@ private[ml] object DefaultParamsReader {
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
- val spark = SparkSession.getActiveSession.get
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadataStr = spark.read.text(metadataPath).first().getString(0)
parseMetadata(metadataStr, expectedClassName)
}
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 5e7965554d82..89e2f9631564 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -464,10 +464,10 @@ class DefaultParamsWriter(MLWriter):
metadataJson = DefaultParamsWriter._get_metadata_to_save(
instance, sc, extraMetadata, paramMap
)
- spark = SparkSession.getActiveSession()
- spark.createDataFrame(  # type: ignore[union-attr]
-     [(metadataJson,)], schema=["value"]
- ).coalesce(1).write.text(metadataPath)
+ spark = SparkSession._getActiveSessionOrCreate()
+ spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
+     metadataPath
+ )
@staticmethod
def _get_metadata_to_save(
@@ -580,8 +580,8 @@ class DefaultParamsReader(MLReader[RL]):
If non empty, this is checked against the loaded metadata.
"""
metadataPath = os.path.join(path, "metadata")
- spark = SparkSession.getActiveSession()
- metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[union-attr,index]
+ spark = SparkSession._getActiveSessionOrCreate()
+ metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[index]
loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
return loadedVals
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]