This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new 87f957dea86 [SPARK-35542][ML] Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols can not be loaded after saving it 87f957dea86 is described below commit 87f957dea86fe1b8c5979e499b5400866b235e43 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Fri Aug 19 12:26:34 2022 +0800 [SPARK-35542][ML] Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols can not be loaded after saving it Signed-off-by: Weichen Xu <weichen...@databricks.com> ### What changes were proposed in this pull request? Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols can not be loaded after saving it ### Why are the changes needed? Bugfix. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test Closes #37568 from WeichenXu123/SPARK-35542. Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> (cherry picked from commit 876ce6a5df118095de51c3c4789d6db6da95eb23) Signed-off-by: Weichen Xu <weichen...@databricks.com> --- python/pyspark/ml/tests/test_persistence.py | 17 ++++++++++++++++- python/pyspark/ml/wrapper.py | 6 +++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index 4f09a49dd04..0b54540f06d 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -32,7 +32,7 @@ from pyspark.ml.classification import ( OneVsRestModel, ) from pyspark.ml.clustering import KMeans -from pyspark.ml.feature import Binarizer, HashingTF, PCA +from pyspark.ml.feature import Binarizer, Bucketizer, HashingTF, PCA from pyspark.ml.linalg import Vectors from pyspark.ml.param import Params from pyspark.ml.pipeline import Pipeline, PipelineModel @@ -518,6 
+518,21 @@ class PersistenceTest(SparkSessionTestCase): ) reader.getAndSetParams(lr, loadedMetadata) + # Test for SPARK-35542 fix. + def test_save_and_load_on_nested_list_params(self): + temp_path = tempfile.mkdtemp() + splitsArray = [ + [-float("inf"), 0.5, 1.4, float("inf")], + [-float("inf"), 0.1, 1.2, float("inf")], + ] + bucketizer = Bucketizer( + splitsArray=splitsArray, inputCols=["values", "values"], outputCols=["b1", "b2"] + ) + savePath = temp_path + "/bk" + bucketizer.write().overwrite().save(savePath) + loadedBucketizer = Bucketizer.load(savePath) + assert loadedBucketizer.getSplitsArray() == splitsArray + if __name__ == "__main__": from pyspark.ml.tests.test_persistence import * # noqa: F401 diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 7853e766244..32856540d6d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -220,7 +220,11 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta): java_param = self._java_obj.getParam(param.name) # SPARK-14931: Only check set params back to avoid default params mismatch. if self._java_obj.isSet(java_param): - value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + java_value = self._java_obj.getOrDefault(java_param) + if param.typeConverter.__name__.startswith("toList"): + value = [_java2py(sc, x) for x in list(java_value)] + else: + value = _java2py(sc, java_value) self._set(**{param.name: value}) # SPARK-10931: Temporary fix for params that have a default in Java if self._java_obj.hasDefault(java_param) and not self.isDefined(param): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org