This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4138efd47139 [SPARK-50944][ML][PYTHON][CONNECT] Support
`KolmogorovSmirnovTest` on Connect
4138efd47139 is described below
commit 4138efd471390701793800060366b2870599bce6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 11:56:00 2025 +0800
[SPARK-50944][ML][PYTHON][CONNECT] Support `KolmogorovSmirnovTest` on
Connect
### What changes were proposed in this pull request?
Support `KolmogorovSmirnovTest` on Connect
### Why are the changes needed?
for feature parity
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
new test
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49791 from zhengruifeng/ml_connect_kst.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 1 +
.../spark/ml/stat/KolmogorovSmirnovTest.scala | 37 ++++++++++++++++++++-
python/pyspark/ml/stat.py | 38 +++++++++++++++++-----
python/pyspark/ml/tests/test_stat.py | 38 +++++++++++++++++++++-
4 files changed, 103 insertions(+), 11 deletions(-)
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 84f3631e5475..1dd431255996 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -79,6 +79,7 @@ org.apache.spark.ml.fpm.PrefixSpanWrapper
# stat
org.apache.spark.ml.stat.ChiSquareTestWrapper
org.apache.spark.ml.stat.CorrelationWrapper
+org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper
# feature
org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
index f4a6b8b033db..2fc4a856b564 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
@@ -21,11 +21,15 @@ import scala.annotation.varargs
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.function.Function
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
/**
* Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a
@@ -114,3 +118,34 @@ object KolmogorovSmirnovTest {
testResult.pValue, testResult.statistic)))
}
}
+
+
+/**
+ * [[KolmogorovSmirnovTest]] is not an Estimator/Transformer and thus needs to be wrapped
+ * in a wrapper to be compatible with Spark Connect.
+ */
+private[spark] class KolmogorovSmirnovTestWrapper(override val uid: String)
+ extends Transformer with HasInputCol {
+
+ val paramsArray = new DoubleArrayParam(this, "paramsArray",
+ "The parameters to be used for the theoretical distribution.")
+
+ val distName = new Param[String](this, "distName",
+ "The name of the theoretical distribution to test against")
+
+ setDefault(paramsArray -> Array.emptyDoubleArray)
+
+ def this() = this(Identifiable.randomUID("KolmogorovSmirnovTestWrapper"))
+
+ override def transformSchema(schema: StructType): StructType = {
+ new StructType()
+ .add("pValue", DoubleType, nullable = false)
+ .add("statistic", DoubleType, nullable = false)
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ KolmogorovSmirnovTest.test(dataset, $(inputCol), $(distName), $(paramsArray).toIndexedSeq: _*)
+ }
+
+ override def copy(extra: ParamMap): KolmogorovSmirnovTestWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index ac11c73ce0f7..16fa9f7edb80 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -272,17 +272,37 @@ class KolmogorovSmirnovTest:
>>> round(ksResult.statistic, 3)
0.175
"""
- from pyspark.core.context import SparkContext
+ if is_remote():
+ from pyspark.ml.wrapper import JavaTransformer
+ from pyspark.ml.connect.serialize import serialize_ml_params_values
- sc = SparkContext._active_spark_context
- assert sc is not None
+ instance = JavaTransformer()
+ instance._java_obj = "org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper"
+ serialized_ml_params = serialize_ml_params_values(
+ {"inputCol": sampleCol, "distName": distName, "paramsArray": list(params)},
+ dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ )
+ instance._serialized_ml_params = serialized_ml_params # type: ignore[attr-defined]
+ return instance.transform(dataset)
- javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.KolmogorovSmirnovTest")
- dataset = _py2java(sc, dataset)
- params = [float(param) for param in params] # type: ignore[assignment]
- return _java2py(
- sc, javaTestObj.test(dataset, sampleCol, distName, _jvm().PythonUtils.toSeq(params))
- )
+ else:
+ from pyspark.core.context import SparkContext
+
+ sc = SparkContext._active_spark_context
+ assert sc is not None
+
+ javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.KolmogorovSmirnovTest")
+ dataset = _py2java(sc, dataset)
+ params = [float(param) for param in params] # type: ignore[assignment]
+ return _java2py(
+ sc,
+ javaTestObj.test(
+ dataset,
+ sampleCol,
+ distName,
+ _jvm().PythonUtils.toSeq(params),
+ ),
+ )
class Summarizer:
diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py
index abe6d88a98ac..bd827e3a27f7 100644
--- a/python/pyspark/ml/tests/test_stat.py
+++ b/python/pyspark/ml/tests/test_stat.py
@@ -19,7 +19,7 @@ import numpy as np
import unittest
from pyspark.ml.linalg import Vectors
-from pyspark.ml.stat import ChiSquareTest, Correlation
+from pyspark.ml.stat import ChiSquareTest, Correlation, KolmogorovSmirnovTest
from pyspark.sql import DataFrame
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -96,6 +96,42 @@ class StatTestsMixin:
corr2,
)
+ def test_kolmogorov_smirnov(self):
+ spark = self.spark
+
+ data = [[-1.0], [0.0], [1.0]]
+ df = spark.createDataFrame(data, ["sample"])
+
+ res1 = KolmogorovSmirnovTest.test(df, "sample", "norm", 0.0, 1.0)
+ self.assertEqual(res1.columns, ["pValue", "statistic"])
+ self.assertEqual(res1.count(), 1)
+
+ row = res1.head()
+ self.assertTrue(
+ np.allclose(row.pValue, 0.9999753186701124, atol=1e-4),
+ row.pValue,
+ )
+ self.assertTrue(
+ np.allclose(row.statistic, 0.1746780794018764, atol=1e-4),
+ row.statistic,
+ )
+
+ data2 = [[2.0], [3.0], [4.0]]
+ df2 = spark.createDataFrame(data2, ["sample"])
+ res2 = KolmogorovSmirnovTest.test(df2, "sample", "norm", 3, 1)
+ self.assertEqual(res2.columns, ["pValue", "statistic"])
+ self.assertEqual(res2.count(), 1)
+
+ row2 = res2.head()
+ self.assertTrue(
+ np.allclose(row2.pValue, 0.9999753186701124, atol=1e-4),
+ row2.pValue,
+ )
+ self.assertTrue(
+ np.allclose(row2.statistic, 0.1746780794018764, atol=1e-4),
+ row2.statistic,
+ )
+
class StatTests(StatTestsMixin, ReusedSQLTestCase):
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]