This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-4.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push: new d18c899be771 [SPARK-50929][ML][PYTHON][CONNECT] Support `LDA` on Connect d18c899be771 is described below commit d18c899be7714aa3bf63118a989078a3c32091bb Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Mon Jan 27 09:19:52 2025 +0800 [SPARK-50929][ML][PYTHON][CONNECT] Support `LDA` on Connect ### What changes were proposed in this pull request? Support `LDA` on Connect ### Why are the changes needed? feature parity ### Does this PR introduce _any_ user-facing change? yes ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #49679 from zhengruifeng/ml_connect_lda. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> (cherry picked from commit b6b00e87b00be9c8ca7103d006c900caf0cb032b) Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../services/org.apache.spark.ml.Estimator | 1 + .../services/org.apache.spark.ml.Transformer | 2 + .../scala/org/apache/spark/ml/clustering/LDA.scala | 4 + python/pyspark/ml/clustering.py | 1 + python/pyspark/ml/tests/test_clustering.py | 139 ++++++++++++++++++++- .../org/apache/spark/sql/connect/ml/MLUtils.scala | 13 ++ 6 files changed, 159 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator index 9c1a1f5a19a6..97526bf1a0c0 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator @@ -37,6 +37,7 @@ org.apache.spark.ml.regression.GBTRegressor org.apache.spark.ml.clustering.KMeans org.apache.spark.ml.clustering.BisectingKMeans org.apache.spark.ml.clustering.GaussianMixture +org.apache.spark.ml.clustering.LDA # recommendation org.apache.spark.ml.recommendation.ALS diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer index 3f1ae52aaaf6..c6faa54c147b 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer @@ -53,6 +53,8 @@ org.apache.spark.ml.regression.GBTRegressionModel org.apache.spark.ml.clustering.KMeansModel org.apache.spark.ml.clustering.BisectingKMeansModel org.apache.spark.ml.clustering.GaussianMixtureModel +org.apache.spark.ml.clustering.DistributedLDAModel +org.apache.spark.ml.clustering.LocalLDAModel # recommendation org.apache.spark.ml.recommendation.ALSModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index b3d3c84db051..3fce96fbfbb0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -617,6 +617,8 @@ class LocalLDAModel private[ml] ( sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null) + oldLocalModel.setSeed(getSeed) @Since("1.6.0") @@ -713,6 +715,8 @@ class DistributedLDAModel private[ml] ( private var oldLocalModelOption: Option[OldLocalLDAModel]) extends LDAModel(uid, vocabSize, sparkSession) { + private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null, None) + override private[clustering] def oldLocalModel: OldLocalLDAModel = { if (oldLocalModelOption.isEmpty) { oldLocalModelOption = Some(oldDistributedModel.toLocal) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6cd508a9e950..8166cd41c834 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1511,6 +1511,7 @@ class LDAModel(JavaModel, _LDAParams): return self._call_java("logPerplexity", dataset) @since("2.0.0") + @try_remote_attribute_relation def describeTopics(self, maxTermsPerTopic: int = 10) -> DataFrame: """ Return the topics described by their top-weighted terms. diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py index e6013d10fa8e..9a26b746f027 100644 --- a/python/pyspark/ml/tests/test_clustering.py +++ b/python/pyspark/ml/tests/test_clustering.py @@ -20,7 +20,7 @@ import unittest import numpy as np -from pyspark.ml.linalg import Vectors +from pyspark.ml.linalg import Vectors, SparseVector from pyspark.sql import SparkSession from pyspark.ml.clustering import ( KMeans, @@ -32,6 +32,10 @@ from pyspark.ml.clustering import ( GaussianMixture, GaussianMixtureModel, GaussianMixtureSummary, + LDA, + LDAModel, + LocalLDAModel, + DistributedLDAModel, ) @@ -264,6 +268,139 @@ class ClusteringTestsMixin: model2 = GaussianMixtureModel.load(d) self.assertEqual(str(model), str(model2)) + def test_local_lda(self): + spark = self.spark + df = ( + spark.createDataFrame( + [ + [1, Vectors.dense([0.0, 1.0])], + [2, SparseVector(2, {0: 1.0})], + ], + ["id", "features"], + ) + .coalesce(1) + .sortWithinPartitions("id") + ) + + lda = LDA(k=2, optimizer="online", seed=1) + lda.setMaxIter(1) + self.assertEqual(lda.getK(), 2) + self.assertEqual(lda.getOptimizer(), "online") + self.assertEqual(lda.getMaxIter(), 1) + self.assertEqual(lda.getSeed(), 1) + + model = lda.fit(df) + self.assertEqual(lda.uid, model.uid) + self.assertIsInstance(model, LDAModel) + self.assertIsInstance(model, LocalLDAModel) + self.assertNotIsInstance(model, DistributedLDAModel) + self.assertFalse(model.isDistributed()) + + dc = model.estimatedDocConcentration() + self.assertTrue(np.allclose(dc.toArray(), [0.5, 0.5], atol=1e-4), dc) + topics = model.topicsMatrix() + self.assertTrue( + np.allclose( + topics.toArray(), [[1.20296728, 1.15740442], [0.99357675, 1.02993164]], atol=1e-4 + ), + topics, + ) + + ll = model.logLikelihood(df) + self.assertTrue(np.allclose(ll, -3.2125122434040088, atol=1e-4), ll) + lp = model.logPerplexity(df) + self.assertTrue(np.allclose(lp, 1.6062561217020044, atol=1e-4), lp) + dt = model.describeTopics() + self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"]) + self.assertEqual(dt.count(), 2) + + # LocalLDAModel specific methods + self.assertEqual(model.vocabSize(), 2) + + output = model.transform(df) + expected_cols = ["id", "features", "topicDistribution"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 2) + + # save & load + with tempfile.TemporaryDirectory(prefix="local_lda") as d: + lda.write().overwrite().save(d) + lda2 = LDA.load(d) + self.assertEqual(str(lda), str(lda2)) + + model.write().overwrite().save(d) + model2 = LocalLDAModel.load(d) + self.assertEqual(str(model), str(model2)) + + def test_distributed_lda(self): + spark = self.spark + df = ( + spark.createDataFrame( + [ + [1, Vectors.dense([0.0, 1.0])], + [2, SparseVector(2, {0: 1.0})], + ], + ["id", "features"], + ) + .coalesce(1) + .sortWithinPartitions("id") + ) + + lda = LDA(k=2, optimizer="em", seed=1) + lda.setMaxIter(1) + + self.assertEqual(lda.getK(), 2) + self.assertEqual(lda.getOptimizer(), "em") + self.assertEqual(lda.getMaxIter(), 1) + self.assertEqual(lda.getSeed(), 1) + + model = lda.fit(df) + self.assertEqual(lda.uid, model.uid) + self.assertIsInstance(model, LDAModel) + self.assertNotIsInstance(model, LocalLDAModel) + self.assertIsInstance(model, DistributedLDAModel) + + dc = model.estimatedDocConcentration() + self.assertTrue(np.allclose(dc.toArray(), [26.0, 26.0], atol=1e-4), dc) + topics = model.topicsMatrix() + self.assertTrue( + np.allclose( + topics.toArray(), [[0.39149926, 0.60850074], [0.60991237, 0.39008763]], atol=1e-4 + ), + topics, + ) + + ll = model.logLikelihood(df) + self.assertTrue(np.allclose(ll, -3.719138517085772, atol=1e-4), ll) + lp = model.logPerplexity(df) + self.assertTrue(np.allclose(lp, 1.859569258542886, atol=1e-4), lp) + + dt = model.describeTopics() + self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"]) + self.assertEqual(dt.count(), 2) + + # DistributedLDAModel specific methods + ll = model.trainingLogLikelihood() + self.assertTrue(np.allclose(ll, -1.3847360462201639, atol=1e-4), ll) + lp = model.logPrior() + self.assertTrue(np.allclose(lp, -69.59963186898915, atol=1e-4), lp) + model.getCheckpointFiles() + + output = model.transform(df) + expected_cols = ["id", "features", "topicDistribution"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 2) + + # save & load + with tempfile.TemporaryDirectory(prefix="distributed_lda") as d: + lda.write().overwrite().save(d) + lda2 = LDA.load(d) + self.assertEqual(str(lda), str(lda2)) + + model.write().overwrite().save(d) + model2 = DistributedLDAModel.load(d) + self.assertEqual(str(model), str(model2)) + class ClusteringTests(ClusteringTestsMixin, unittest.TestCase): def setUp(self) -> None: diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index fbcbf8f3f204..9bf3c632b219 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -594,6 +594,19 @@ private[ml] object MLUtils { classOf[GaussianMixtureModel], Set("predict", "numFeatures", "weights", "gaussians", "predictProbability", "gaussiansDF")), (classOf[GaussianMixtureSummary], Set("probability", "probabilityCol", "logLikelihood")), + ( + classOf[LDAModel], + Set( + "estimatedDocConcentration", + "topicsMatrix", + "isDistributed", + "logLikelihood", + "logPerplexity", + "describeTopics")), + (classOf[LocalLDAModel], Set("vocabSize")), + ( + classOf[DistributedLDAModel], + Set("trainingLogLikelihood", "logPrior", "getCheckpointFiles")), // Recommendation Models ( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org