This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 5d43557780d4 [SPARK-50930][ML][PYTHON][CONNECT] Support `PowerIterationClustering` on Connect
5d43557780d4 is described below

commit 5d43557780d4245bc8723ef94603d78d349ab1bb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 11:57:57 2025 +0800

    [SPARK-50930][ML][PYTHON][CONNECT] Support `PowerIterationClustering` on Connect
    
    ### What changes were proposed in this pull request?
    Support `PowerIterationClustering` on Connect
    
    ### Why are the changes needed?
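    Feature parity with classic PySpark: `PowerIterationClustering.assignClusters` was not supported on Spark Connect.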
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `PowerIterationClustering.assignClusters` can now be called from a Spark Connect session.
    
    ### How was this patch tested?
    Added `test_power_iteration_clustering` to `python/pyspark/ml/tests/test_clustering.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49792 from zhengruifeng/ml_connect_pic.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 5af4c764aa8b91146b25c911727ae08a1603104c)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       |  1 +
 .../ml/clustering/PowerIterationClustering.scala   | 28 +++++++++++++++++
 python/pyspark/ml/clustering.py                    | 15 +++++++++-
 python/pyspark/ml/tests/test_clustering.py         | 35 ++++++++++++++++++++++
 4 files changed, 78 insertions(+), 1 deletion(-)

diff --git 
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer 
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 1dd431255996..247c9c912f5a 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -68,6 +68,7 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
 org.apache.spark.ml.clustering.GaussianMixtureModel
 org.apache.spark.ml.clustering.DistributedLDAModel
 org.apache.spark.ml.clustering.LocalLDAModel
+org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
 
 # recommendation
 org.apache.spark.ml.recommendation.ALSModel
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 8b2ee955d6a5..6e7028d8f99e 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -191,3 +192,30 @@ object PowerIterationClustering extends 
DefaultParamsReadable[PowerIterationClus
   @Since("2.4.0")
   override def load(path: String): PowerIterationClustering = super.load(path)
 }
+
+/**
+ * Exposes [[PowerIterationClustering.assignClusters]] as a [[Transformer]] so that Spark
+ * Connect can run it through the generic ML transform path. It holds no fitted state:
+ * every call to `transform` runs `assignClusters` on the given dataset.
+ */
+private[spark] class PowerIterationClusteringWrapper(override val uid: String)
+  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+
+  def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
+
+  // copyValues forwards every shared param (set values and defaults) by name.
+  // An explicitly set but empty weightCol is cleared so the target sees it as unset.
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val pic = copyValues(new PowerIterationClustering())
+    if (get(weightCol).exists(_.isEmpty)) pic.clear(pic.weightCol)
+    pic.assignClusters(dataset)
+  }
+
+  // Fixed output schema (id, cluster); the input schema is not inspected here.
+  override def transformSchema(schema: StructType): StructType =
+    new StructType()
+      .add(StructField("id", LongType, nullable = false))
+      .add(StructField("cluster", IntegerType, nullable = false))
+
+  override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index e05c44fc18b1..dcd34ba365a5 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -2152,9 +2152,22 @@ class PowerIterationClustering(
             - id: Long
             - cluster: Int
         """
-        self._transfer_params_to_java()
         assert self._java_obj is not None
 
+        if is_remote():
+            from pyspark.ml.wrapper import JavaTransformer
+            from pyspark.ml.connect.serialize import serialize_ml_params
+
+            instance = JavaTransformer()
+            instance._java_obj = 
"org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+            instance._serialized_ml_params = serialize_ml_params(  # type: 
ignore[attr-defined]
+                self,
+                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            )
+            return instance.transform(dataset)
+
+        self._transfer_params_to_java()
+
         jdf = self._java_obj.assignClusters(dataset._jdf)
         return DataFrame(jdf, dataset.sparkSession)
 
diff --git a/python/pyspark/ml/tests/test_clustering.py 
b/python/pyspark/ml/tests/test_clustering.py
index 42fe7a76256a..136a166d2218 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -36,6 +36,7 @@ from pyspark.ml.clustering import (
     LDAModel,
     LocalLDAModel,
     DistributedLDAModel,
+    PowerIterationClustering,
 )
 
 
@@ -461,6 +462,40 @@ class ClusteringTestsMixin:
             model2 = DistributedLDAModel.load(d)
             self.assertEqual(str(model), str(model2))
 
+    # TODO(SPARK-51080): Fix save/load for PowerIterationClustering
+    def test_power_iteration_clustering(self):
+        # Weighted edges (src, dst, weight); the graph's vertices are 0..5.
+        data = [
+            (1, 0, 0.5),
+            (2, 0, 0.5),
+            (2, 1, 0.7),
+            (3, 0, 0.5),
+            (3, 1, 0.7),
+            (3, 2, 0.9),
+            (4, 0, 0.5),
+            (4, 1, 0.7),
+            (4, 2, 0.9),
+            (4, 3, 1.1),
+            (5, 0, 0.5),
+            (5, 1, 0.7),
+            (5, 2, 0.9),
+            (5, 3, 1.1),
+            (5, 4, 1.3),
+        ]
+        df = self.spark.createDataFrame(data, ["src", "dst", "weight"]).repartition(1)
+
+        pic = PowerIterationClustering(k=2, weightCol="weight").setMaxIter(40)
+        self.assertEqual(pic.getK(), 2)
+        self.assertEqual(pic.getMaxIter(), 40)
+        self.assertEqual(pic.getWeightCol(), "weight")
+
+        assignments = pic.assignClusters(df)
+        self.assertEqual(assignments.columns, ["id", "cluster"])
+        rows = assignments.collect()
+        # Every vertex is assigned exactly once, to a cluster index in [0, k).
+        self.assertEqual(sorted(row.id for row in rows), [0, 1, 2, 3, 4, 5])
+        self.assertTrue(all(0 <= row.cluster < 2 for row in rows))
+
 
 class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
     def setUp(self) -> None:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to