This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 40ccabfd681 [SPARK-44908][ML][CONNECT] Fix cross validator foldCol param functionality 40ccabfd681 is described below commit 40ccabfd68141eabe8f9b9bf15acad9fc6b7dff1 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Wed Aug 23 18:19:15 2023 +0800 [SPARK-44908][ML][CONNECT] Fix cross validator foldCol param functionality ### What changes were proposed in this pull request? Fix cross validator foldCol param functionality. On the main branch the code calls `df.rdd` APIs, which are not supported in Spark Connect. ### Why are the changes needed? Bug fix. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42605 from WeichenXu123/fix-tuning-connect-foldCol. Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> (cherry picked from commit 0d1b5975b2d308c616312d53b9f7ad754348a266) Signed-off-by: Weichen Xu <weichen...@databricks.com> --- python/pyspark/ml/connect/tuning.py | 24 ++++++--------------- .../ml/tests/connect/test_legacy_mode_tuning.py | 25 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index c22c31e84e8..871e448966c 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -42,8 +42,7 @@ from pyspark.ml.connect.io_utils import ( ) from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasParallelism, HasSeed -from pyspark.sql.functions import col, lit, rand, UserDefinedFunction -from pyspark.sql.types import BooleanType +from pyspark.sql.functions import col, lit, rand from pyspark.sql.dataframe import DataFrame from pyspark.sql import SparkSession @@ -477,23 +476,14 @@ class CrossValidator( train = df.filter(~condition) 
datasets.append((train, validation)) else: - # Use user-specified fold numbers. - def checker(foldNum: int) -> bool: - if foldNum < 0 or foldNum >= nFolds: - raise ValueError( - "Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum) - ) - return True - - checker_udf = UserDefinedFunction(checker, BooleanType()) + # TODO: + # Add verification that foldCol column values are in range [0, nFolds) for i in range(nFolds): - training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i))) - validation = dataset.filter( - checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i)) - ) - if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0: + training = dataset.filter(col(foldCol) != lit(i)) + validation = dataset.filter(col(foldCol) == lit(i)) + if training.isEmpty(): raise ValueError("The training data at fold %s is empty." % i) - if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0: + if validation.isEmpty(): raise ValueError("The validation data at fold %s is empty." 
% i) datasets.append((training, validation)) diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py index d6c813533d6..0ade227540c 100644 --- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py @@ -246,6 +246,31 @@ class CrossValidatorTestsMixin: np.testing.assert_allclose(cv_model.avgMetrics, loaded_cv_model.avgMetrics) np.testing.assert_allclose(cv_model.stdMetrics, loaded_cv_model.stdMetrics) + def test_crossvalidator_with_fold_col(self): + sk_dataset = load_breast_cancer() + + train_dataset = self.spark.createDataFrame( + zip( + sk_dataset.data.tolist(), + [int(t) for t in sk_dataset.target], + [int(i % 3) for i in range(len(sk_dataset.target))], + ), + schema="features: array<double>, label: long, fold: long", + ) + + lorv2 = LORV2(numTrainWorkers=2) + + grid2 = ParamGridBuilder().addGrid(lorv2.maxIter, [2, 200]).build() + cv = CrossValidator( + estimator=lorv2, + estimatorParamMaps=grid2, + parallelism=2, + evaluator=BinaryClassificationEvaluator(), + foldCol="fold", + numFolds=3, + ) + cv.fit(train_dataset) + class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase): def setUp(self) -> None: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org