This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4138efd47139 [SPARK-50944][ML][PYTHON][CONNECT] Support `KolmogorovSmirnovTest` on Connect
4138efd47139 is described below

commit 4138efd471390701793800060366b2870599bce6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 11:56:00 2025 +0800

    [SPARK-50944][ML][PYTHON][CONNECT] Support `KolmogorovSmirnovTest` on Connect
    
    ### What changes were proposed in this pull request?
    Support `KolmogorovSmirnovTest` on Connect
    
    ### Why are the changes needed?
    for feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    new test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49791 from zhengruifeng/ml_connect_kst.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       |  1 +
 .../spark/ml/stat/KolmogorovSmirnovTest.scala      | 37 ++++++++++++++++++++-
 python/pyspark/ml/stat.py                          | 38 +++++++++++++++++-----
 python/pyspark/ml/tests/test_stat.py               | 38 +++++++++++++++++++++-
 4 files changed, 103 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 84f3631e5475..1dd431255996 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -79,6 +79,7 @@ org.apache.spark.ml.fpm.PrefixSpanWrapper
 # stat
 org.apache.spark.ml.stat.ChiSquareTestWrapper
 org.apache.spark.ml.stat.CorrelationWrapper
+org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper
 
 # feature
 org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
index f4a6b8b033db..2fc4a856b564 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
@@ -21,11 +21,15 @@ import scala.annotation.varargs
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.function.Function
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
 
 /**
  * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a
@@ -114,3 +118,34 @@ object KolmogorovSmirnovTest {
       testResult.pValue, testResult.statistic)))
   }
 }
+
+
+/**
+ * [[KolmogorovSmirnovTest]] is not an Estimator/Transformer and thus needs to be wrapped
+ * in a wrapper to be compatible with Spark Connect.
+ */
+private[spark] class KolmogorovSmirnovTestWrapper(override val uid: String)
+  extends Transformer with HasInputCol {
+
+  val paramsArray = new DoubleArrayParam(this, "paramsArray",
+    "The parameters to be used for the theoretical distribution.")
+
+  val distName = new Param[String](this, "distName",
+    "The name of the theoretical distribution to test against")
+
+  setDefault(paramsArray -> Array.emptyDoubleArray)
+
+  def this() = this(Identifiable.randomUID("KolmogorovSmirnovTestWrapper"))
+
+  override def transformSchema(schema: StructType): StructType = {
+    new StructType()
+      .add("pValue", DoubleType, nullable = false)
+      .add("statistic", DoubleType, nullable = false)
+  }
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    KolmogorovSmirnovTest.test(dataset, $(inputCol), $(distName), $(paramsArray).toIndexedSeq: _*)
+  }
+
+  override def copy(extra: ParamMap): KolmogorovSmirnovTestWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index ac11c73ce0f7..16fa9f7edb80 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -272,17 +272,37 @@ class KolmogorovSmirnovTest:
         >>> round(ksResult.statistic, 3)
         0.175
         """
-        from pyspark.core.context import SparkContext
+        if is_remote():
+            from pyspark.ml.wrapper import JavaTransformer
+            from pyspark.ml.connect.serialize import serialize_ml_params_values
 
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+            instance = JavaTransformer()
+            instance._java_obj = "org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper"
+            serialized_ml_params = serialize_ml_params_values(
+                {"inputCol": sampleCol, "distName": distName, "paramsArray": 
list(params)},
+                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            )
+            instance._serialized_ml_params = serialized_ml_params  # type: ignore[attr-defined]
+            return instance.transform(dataset)
 
-        javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.KolmogorovSmirnovTest")
-        dataset = _py2java(sc, dataset)
-        params = [float(param) for param in params]  # type: ignore[assignment]
-        return _java2py(
-            sc, javaTestObj.test(dataset, sampleCol, distName, _jvm().PythonUtils.toSeq(params))
-        )
+        else:
+            from pyspark.core.context import SparkContext
+
+            sc = SparkContext._active_spark_context
+            assert sc is not None
+
+            javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.KolmogorovSmirnovTest")
+            dataset = _py2java(sc, dataset)
+            params = [float(param) for param in params]  # type: ignore[assignment]
+            return _java2py(
+                sc,
+                javaTestObj.test(
+                    dataset,
+                    sampleCol,
+                    distName,
+                    _jvm().PythonUtils.toSeq(params),
+                ),
+            )
 
 
 class Summarizer:
diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py
index abe6d88a98ac..bd827e3a27f7 100644
--- a/python/pyspark/ml/tests/test_stat.py
+++ b/python/pyspark/ml/tests/test_stat.py
@@ -19,7 +19,7 @@ import numpy as np
 import unittest
 
 from pyspark.ml.linalg import Vectors
-from pyspark.ml.stat import ChiSquareTest, Correlation
+from pyspark.ml.stat import ChiSquareTest, Correlation, KolmogorovSmirnovTest
 from pyspark.sql import DataFrame
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -96,6 +96,42 @@ class StatTestsMixin:
             corr2,
         )
 
+    def test_kolmogorov_smirnov(self):
+        spark = self.spark
+
+        data = [[-1.0], [0.0], [1.0]]
+        df = spark.createDataFrame(data, ["sample"])
+
+        res1 = KolmogorovSmirnovTest.test(df, "sample", "norm", 0.0, 1.0)
+        self.assertEqual(res1.columns, ["pValue", "statistic"])
+        self.assertEqual(res1.count(), 1)
+
+        row = res1.head()
+        self.assertTrue(
+            np.allclose(row.pValue, 0.9999753186701124, atol=1e-4),
+            row.pValue,
+        )
+        self.assertTrue(
+            np.allclose(row.statistic, 0.1746780794018764, atol=1e-4),
+            row.statistic,
+        )
+
+        data2 = [[2.0], [3.0], [4.0]]
+        df2 = spark.createDataFrame(data2, ["sample"])
+        res2 = KolmogorovSmirnovTest.test(df2, "sample", "norm", 3, 1)
+        self.assertEqual(res2.columns, ["pValue", "statistic"])
+        self.assertEqual(res2.count(), 1)
+
+        row2 = res2.head()
+        self.assertTrue(
+            np.allclose(row2.pValue, 0.9999753186701124, atol=1e-4),
+            row2.pValue,
+        )
+        self.assertTrue(
+            np.allclose(row2.statistic, 0.1746780794018764, atol=1e-4),
+            row2.statistic,
+        )
+
 
 class StatTests(StatTestsMixin, ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to