This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 66c29209eee6 [SPARK-50995][ML][PYTHON][CONNECT] Support `clusterCenters` for KMeans and BisectingKMeans
66c29209eee6 is described below

commit 66c29209eee6c89bbbae0516d650f3898312281b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Jan 27 09:22:54 2025 +0800

    [SPARK-50995][ML][PYTHON][CONNECT] Support `clusterCenters` for KMeans and BisectingKMeans
    
    ### What changes were proposed in this pull request?
    Support `clusterCenters` for KMeans and BisectingKMeans on Spark Connect.
    
    To simplify the serde of `Array[Vector]`, combine the cluster centers into a single `Matrix` (one row per center).
    
    ### Why are the changes needed?
    For parity with the classic (non-Connect) PySpark ML API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `clusterCenters` is now supported on Spark Connect.
    
    ### How was this patch tested?
    Added tests for `clusterCenters` in `python/pyspark/ml/tests/test_clustering.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49680 from zhengruifeng/ml_connect_km_cluster.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/ml/clustering/BisectingKMeans.scala     |  5 ++++-
 .../main/scala/org/apache/spark/ml/clustering/KMeans.scala   |  3 +++
 python/pyspark/ml/clustering.py                              |  6 ++++--
 python/pyspark/ml/tests/test_clustering.py                   | 12 ++++++++++++
 .../main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala |  4 ++--
 5 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index d0e5cb42c41c..c1ef69e8b047 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -142,6 +142,9 @@ class BisectingKMeansModel private[ml] (
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
 
+  private[ml] def clusterCenterMatrix: Matrix =
+    Matrices.fromVectors(clusterCenters.toSeq)
+
   /**
    * Computes the sum of squared distances between the input points and their corresponding cluster
    * centers.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 17d34a277af2..e878e12f4df4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -187,6 +187,9 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
 
+  private[ml] def clusterCenterMatrix: Matrix =
+    Matrices.fromVectors(clusterCenters.toSeq)
+
   /**
    * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
    *
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 8166cd41c834..74c9d8705796 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -686,7 +686,8 @@ class KMeansModel(
     @since("1.5.0")
     def clusterCenters(self) -> List[np.ndarray]:
         """Get the cluster centers, represented as a list of NumPy arrays."""
-        return [c.toArray() for c in self._call_java("clusterCenters")]
+        matrix = self._call_java("clusterCenterMatrix")
+        return [vec for vec in matrix.toArray()]
 
     @property
     @since("2.1.0")
@@ -1006,7 +1007,8 @@ class BisectingKMeansModel(
     @since("2.0.0")
     def clusterCenters(self) -> List[np.ndarray]:
         """Get the cluster centers, represented as a list of NumPy arrays."""
-        return [c.toArray() for c in self._call_java("clusterCenters")]
+        matrix = self._call_java("clusterCenterMatrix")
+        return [vec for vec in matrix.toArray()]
 
     @since("2.0.0")
     def computeCost(self, dataset: DataFrame) -> float:
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 9a26b746f027..380342e337a0 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -69,6 +69,12 @@ class ClusteringTestsMixin:
 
         model = km.fit(df)
         self.assertEqual(km.uid, model.uid)
+
+        centers = model.clusterCenters()
+        self.assertEqual(len(centers), 2)
+        self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
+        self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
+
         # TODO: support KMeansModel.numFeatures in Python
         # self.assertEqual(model.numFeatures, 2)
 
@@ -138,6 +144,12 @@ class ClusteringTestsMixin:
 
         model = bkm.fit(df)
         self.assertEqual(bkm.uid, model.uid)
+
+        centers = model.clusterCenters()
+        self.assertEqual(len(centers), 2)
+        self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
+        self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
+
         # TODO: support KMeansModel.numFeatures in Python
         # self.assertEqual(model.numFeatures, 2)
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 9bf3c632b219..75aed57ae2d2 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -584,11 +584,11 @@ private[ml] object MLUtils {
     (classOf[LinearRegressionTrainingSummary], Set("objectiveHistory", "totalIterations")),
 
     // Clustering Models
-    (classOf[KMeansModel], Set("predict", "numFeatures", "clusterCenters")),
+    (classOf[KMeansModel], Set("predict", "numFeatures", "clusterCenterMatrix")),
     (classOf[KMeansSummary], Set("trainingCost")),
     (
       classOf[BisectingKMeansModel],
-      Set("predict", "numFeatures", "clusterCenters", "computeCost")),
+      Set("predict", "numFeatures", "clusterCenterMatrix", "computeCost")),
     (classOf[BisectingKMeansSummary], Set("trainingCost")),
     (
       classOf[GaussianMixtureModel],


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to