This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new be1afd504a44 [SPARK-48941][PYTHON][ML] Replace RDD read / write API invocation with Dataframe read / write API
be1afd504a44 is described below
commit be1afd504a44ea6058f764e0adf7140eedf704db
Author: Weichen Xu <[email protected]>
AuthorDate: Mon Jul 22 21:18:54 2024 +0800
[SPARK-48941][PYTHON][ML] Replace RDD read / write API invocation with Dataframe read / write API
### What changes were proposed in this pull request?
PySpark ML: replace RDD read / write API invocations with the DataFrame read / write API.
### Why are the changes needed?
Follow-up of https://github.com/apache/spark/pull/47341
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47411 from WeichenXu123/SPARK-48909-follow-up.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
python/pyspark/ml/util.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index b9a2829a1ca0..5e7965554d82 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -464,7 +464,10 @@ class DefaultParamsWriter(MLWriter):
metadataJson = DefaultParamsWriter._get_metadata_to_save(
instance, sc, extraMetadata, paramMap
)
- sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
+ spark = SparkSession.getActiveSession()
+ spark.createDataFrame( # type: ignore[union-attr]
+ [(metadataJson,)], schema=["value"]
+ ).coalesce(1).write.text(metadataPath)
@staticmethod
def _get_metadata_to_save(
@@ -577,7 +580,8 @@ class DefaultParamsReader(MLReader[RL]):
If non empty, this is checked against the loaded metadata.
"""
metadataPath = os.path.join(path, "metadata")
- metadataStr = sc.textFile(metadataPath, 1).first()
+ spark = SparkSession.getActiveSession()
+        metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[union-attr,index]
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
return loadedVals
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]