This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 40ccabfd681 [SPARK-44908][ML][CONNECT] Fix cross validator foldCol param functionality
40ccabfd681 is described below

commit 40ccabfd68141eabe8f9b9bf15acad9fc6b7dff1
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Wed Aug 23 18:19:15 2023 +0800

    [SPARK-44908][ML][CONNECT] Fix cross validator foldCol param functionality
    
    ### What changes were proposed in this pull request?
    
    Fix the `foldCol` param functionality of the Spark Connect `CrossValidator`.
    On the main branch, the code calls `df.rdd` APIs, which are not supported
    in Spark Connect.
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests (added `test_crossvalidator_with_fold_col`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42605 from WeichenXu123/fix-tuning-connect-foldCol.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
    (cherry picked from commit 0d1b5975b2d308c616312d53b9f7ad754348a266)
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/connect/tuning.py                | 24 ++++++---------------
 .../ml/tests/connect/test_legacy_mode_tuning.py    | 25 ++++++++++++++++++++++
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/ml/connect/tuning.py 
b/python/pyspark/ml/connect/tuning.py
index c22c31e84e8..871e448966c 100644
--- a/python/pyspark/ml/connect/tuning.py
+++ b/python/pyspark/ml/connect/tuning.py
@@ -42,8 +42,7 @@ from pyspark.ml.connect.io_utils import (
 )
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasParallelism, HasSeed
-from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
-from pyspark.sql.types import BooleanType
+from pyspark.sql.functions import col, lit, rand
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import SparkSession
 
@@ -477,23 +476,14 @@ class CrossValidator(
                 train = df.filter(~condition)
                 datasets.append((train, validation))
         else:
-            # Use user-specified fold numbers.
-            def checker(foldNum: int) -> bool:
-                if foldNum < 0 or foldNum >= nFolds:
-                    raise ValueError(
-                        "Fold number must be in range [0, %s), but got %s." % 
(nFolds, foldNum)
-                    )
-                return True
-
-            checker_udf = UserDefinedFunction(checker, BooleanType())
+            # TODO:
+            #  Add verification that foldCol column values are in range [0, 
nFolds)
             for i in range(nFolds):
-                training = dataset.filter(checker_udf(dataset[foldCol]) & 
(col(foldCol) != lit(i)))
-                validation = dataset.filter(
-                    checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i))
-                )
-                if training.rdd.getNumPartitions() == 0 or 
len(training.take(1)) == 0:
+                training = dataset.filter(col(foldCol) != lit(i))
+                validation = dataset.filter(col(foldCol) == lit(i))
+                if training.isEmpty():
                     raise ValueError("The training data at fold %s is empty." 
% i)
-                if validation.rdd.getNumPartitions() == 0 or 
len(validation.take(1)) == 0:
+                if validation.isEmpty():
                     raise ValueError("The validation data at fold %s is 
empty." % i)
                 datasets.append((training, validation))
 
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index d6c813533d6..0ade227540c 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -246,6 +246,31 @@ class CrossValidatorTestsMixin:
             np.testing.assert_allclose(cv_model.avgMetrics, 
loaded_cv_model.avgMetrics)
             np.testing.assert_allclose(cv_model.stdMetrics, 
loaded_cv_model.stdMetrics)
 
+    def test_crossvalidator_with_fold_col(self):
+        sk_dataset = load_breast_cancer()
+
+        train_dataset = self.spark.createDataFrame(
+            zip(
+                sk_dataset.data.tolist(),
+                [int(t) for t in sk_dataset.target],
+                [int(i % 3) for i in range(len(sk_dataset.target))],
+            ),
+            schema="features: array<double>, label: long, fold: long",
+        )
+
+        lorv2 = LORV2(numTrainWorkers=2)
+
+        grid2 = ParamGridBuilder().addGrid(lorv2.maxIter, [2, 200]).build()
+        cv = CrossValidator(
+            estimator=lorv2,
+            estimatorParamMaps=grid2,
+            parallelism=2,
+            evaluator=BinaryClassificationEvaluator(),
+            foldCol="fold",
+            numFolds=3,
+        )
+        cv.fit(train_dataset)
+
 
 class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
     def setUp(self) -> None:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to