This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 5d43557780d4 [SPARK-50930][ML][PYTHON][CONNECT] Support
`PowerIterationClustering` on Connect
5d43557780d4 is described below
commit 5d43557780d4245bc8723ef94603d78d349ab1bb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 11:57:57 2025 +0800
[SPARK-50930][ML][PYTHON][CONNECT] Support `PowerIterationClustering` on
Connect
### What changes were proposed in this pull request?
Support `PowerIterationClustering` on Connect
### Why are the changes needed?
feature parity
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49792 from zhengruifeng/ml_connect_pic.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 5af4c764aa8b91146b25c911727ae08a1603104c)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 1 +
.../ml/clustering/PowerIterationClustering.scala | 28 +++++++++++++++++
python/pyspark/ml/clustering.py | 15 +++++++++-
python/pyspark/ml/tests/test_clustering.py | 35 ++++++++++++++++++++++
4 files changed, 78 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 1dd431255996..247c9c912f5a 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -68,6 +68,7 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
org.apache.spark.ml.clustering.GaussianMixtureModel
org.apache.spark.ml.clustering.DistributedLDAModel
org.apache.spark.ml.clustering.LocalLDAModel
+org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
# recommendation
org.apache.spark.ml.recommendation.ALSModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 8b2ee955d6a5..6e7028d8f99e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.clustering
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -191,3 +192,30 @@ object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering]
@Since("2.4.0")
override def load(path: String): PowerIterationClustering = super.load(path)
}
+
+private[spark] class PowerIterationClusteringWrapper(override val uid: String)
+ extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+
+ def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val pic = new PowerIterationClustering()
+ .setK($(k))
+ .setInitMode($(initMode))
+ .setMaxIter($(maxIter))
+ .setSrcCol($(srcCol))
+ .setDstCol($(dstCol))
+ get(weightCol) match {
+ case Some(w) if w.nonEmpty => pic.setWeightCol(w)
+ case _ =>
+ }
+ pic.assignClusters(dataset)
+ }
+
+ override def transformSchema(schema: StructType): StructType =
+ new StructType()
+ .add(StructField("id", LongType, nullable = false))
+ .add(StructField("cluster", IntegerType, nullable = false))
+
+ override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index e05c44fc18b1..dcd34ba365a5 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -2152,9 +2152,22 @@ class PowerIterationClustering(
- id: Long
- cluster: Int
"""
- self._transfer_params_to_java()
assert self._java_obj is not None
+ if is_remote():
+ from pyspark.ml.wrapper import JavaTransformer
+ from pyspark.ml.connect.serialize import serialize_ml_params
+
+ instance = JavaTransformer()
+ instance._java_obj = "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+ instance._serialized_ml_params = serialize_ml_params(  # type: ignore[attr-defined]
+ self,
+ dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ )
+ return instance.transform(dataset)
+
+ self._transfer_params_to_java()
+
jdf = self._java_obj.assignClusters(dataset._jdf)
return DataFrame(jdf, dataset.sparkSession)
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 42fe7a76256a..136a166d2218 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -36,6 +36,7 @@ from pyspark.ml.clustering import (
LDAModel,
LocalLDAModel,
DistributedLDAModel,
+ PowerIterationClustering,
)
@@ -461,6 +462,40 @@ class ClusteringTestsMixin:
model2 = DistributedLDAModel.load(d)
self.assertEqual(str(model), str(model2))
+ # TODO(SPARK-51080): Fix save/load for PowerIterationClustering
+ def test_power_iteration_clustering(self):
+ spark = self.spark
+
+ data = [
+ (1, 0, 0.5),
+ (2, 0, 0.5),
+ (2, 1, 0.7),
+ (3, 0, 0.5),
+ (3, 1, 0.7),
+ (3, 2, 0.9),
+ (4, 0, 0.5),
+ (4, 1, 0.7),
+ (4, 2, 0.9),
+ (4, 3, 1.1),
+ (5, 0, 0.5),
+ (5, 1, 0.7),
+ (5, 2, 0.9),
+ (5, 3, 1.1),
+ (5, 4, 1.3),
+ ]
+ df = spark.createDataFrame(data, ["src", "dst", "weight"]).repartition(1)
+
+ pic = PowerIterationClustering(k=2, weightCol="weight")
+ pic.setMaxIter(40)
+
+ self.assertEqual(pic.getK(), 2)
+ self.assertEqual(pic.getMaxIter(), 40)
+ self.assertEqual(pic.getWeightCol(), "weight")
+
+ assignments = pic.assignClusters(df)
+ self.assertEqual(assignments.columns, ["id", "cluster"])
+ self.assertEqual(assignments.count(), 6)
+
class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
def setUp(self) -> None:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]