This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 28bbb5481c59 [SPARK-50878][ML][PYTHON][CONNECT] Support ALS on Connect
28bbb5481c59 is described below

commit 28bbb5481c5986931a480f4e1efafa864648aed6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 20 20:52:39 2025 +0800

    [SPARK-50878][ML][PYTHON][CONNECT] Support ALS on Connect
    
    ### What changes were proposed in this pull request?
    Support ALS on Connect
    
    ### Why are the changes needed?
    For feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes, ALS supported
    
    ### How was this patch tested?
    Added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49570 from zhengruifeng/ml_connect_als.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 78e4b79a15953e3d825e29c4cf54cb5cd4c48a12)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |  1 +
 .../services/org.apache.spark.ml.Estimator         |  4 ++
 python/pyspark/ml/recommendation.py                |  8 ++-
 .../pyspark/ml/tests/connect/test_parity_als.py    | 30 ++++-----
 python/pyspark/ml/tests/test_als.py                | 72 +++++++++++++++++++++-
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  9 ++-
 6 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f34a33dd4b69..cacd4a83bbe4 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1118,6 +1118,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_connect_classification",
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
+        "pyspark.ml.tests.connect.test_parity_als",
         "pyspark.ml.tests.connect.test_parity_classification",
         "pyspark.ml.tests.connect.test_parity_regression",
         "pyspark.ml.tests.connect.test_parity_clustering",
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 37b9c7e6aeb8..a7d7d3da9df3 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -35,3 +35,7 @@ org.apache.spark.ml.regression.GBTRegressor
 # clustering
 org.apache.spark.ml.clustering.KMeans
 org.apache.spark.ml.clustering.BisectingKMeans
+
+
+# recommendation
+org.apache.spark.ml.recommendation.ALS
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 873140e51afb..d11990634593 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -30,7 +30,7 @@ from pyspark.ml.param.shared import (
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.param import Params, TypeConverters, Param
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable
+from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
 from pyspark.sql import DataFrame
 
 if TYPE_CHECKING:
@@ -617,6 +617,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
 
     @property
     @since("1.4.0")
+    @try_remote_attribute_relation
     def userFactors(self) -> DataFrame:
         """
         a DataFrame that stores user factors in two columns: `id` and
@@ -626,6 +627,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
 
     @property
     @since("1.4.0")
+    @try_remote_attribute_relation
     def itemFactors(self) -> DataFrame:
         """
         a DataFrame that stores item factors in two columns: `id` and
@@ -633,6 +635,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
         """
         return self._call_java("itemFactors")
 
+    @try_remote_attribute_relation
     def recommendForAllUsers(self, numItems: int) -> DataFrame:
         """
         Returns top `numItems` items recommended for each user, for all users.
@@ -652,6 +655,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
         """
         return self._call_java("recommendForAllUsers", numItems)
 
+    @try_remote_attribute_relation
     def recommendForAllItems(self, numUsers: int) -> DataFrame:
         """
         Returns top `numUsers` users recommended for each item, for all items.
@@ -671,6 +675,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
         """
         return self._call_java("recommendForAllItems", numUsers)
 
+    @try_remote_attribute_relation
     def recommendForUserSubset(self, dataset: DataFrame, numItems: int) -> DataFrame:
         """
         Returns top `numItems` items recommended for each user id in the input data set. Note that
@@ -694,6 +699,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSMo
         """
         return self._call_java("recommendForUserSubset", dataset, numItems)
 
+    @try_remote_attribute_relation
     def recommendForItemSubset(self, dataset: DataFrame, numUsers: int) -> DataFrame:
         """
         Returns top `numUsers` users recommended for each item id in the input data set. Note that
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/python/pyspark/ml/tests/connect/test_parity_als.py
similarity index 52%
copy from mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
copy to python/pyspark/ml/tests/connect/test_parity_als.py
index 37b9c7e6aeb8..e9611900e550 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/python/pyspark/ml/tests/connect/test_parity_als.py
@@ -15,23 +15,23 @@
 # limitations under the License.
 #
 
-# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
-# So register the supported estimator here if you're trying to add a new one.
+import unittest
 
-# classification
-org.apache.spark.ml.classification.LogisticRegression
-org.apache.spark.ml.classification.DecisionTreeClassifier
-org.apache.spark.ml.classification.RandomForestClassifier
-org.apache.spark.ml.classification.GBTClassifier
+from pyspark.ml.tests.test_als import ALSTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-# regression
-org.apache.spark.ml.regression.LinearRegression
-org.apache.spark.ml.regression.DecisionTreeRegressor
-org.apache.spark.ml.regression.RandomForestRegressor
-org.apache.spark.ml.regression.GBTRegressor
+class ALSParityTests(ALSTestsMixin, ReusedConnectTestCase):
+    pass
 
 
-# clustering
-org.apache.spark.ml.clustering.KMeans
-org.apache.spark.ml.clustering.BisectingKMeans
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_als import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_als.py b/python/pyspark/ml/tests/test_als.py
index 3027b3ab9fd6..bd6cd1cb212f 100644
--- a/python/pyspark/ml/tests/test_als.py
+++ b/python/pyspark/ml/tests/test_als.py
@@ -23,7 +23,73 @@ from pyspark.ml.recommendation import ALS, ALSModel
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class ALSTest(ReusedSQLTestCase):
+class ALSTestsMixin:
+    def test_als(self):
+        df = (
+            self.spark.createDataFrame(
+                [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
+                ["user", "item", "rating"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("user", "item")
+        )
+
+        als = ALS(
+            userCol="user",
+            itemCol="item",
+            rank=10,
+            seed=1,
+        )
+        als.setMaxIter(2)
+
+        self.assertEqual(als.getUserCol(), "user")
+        self.assertEqual(als.getItemCol(), "item")
+        self.assertEqual(als.getRank(), 10)
+        self.assertEqual(als.getSeed(), 1)
+        self.assertEqual(als.getMaxIter(), 2)
+
+        # Estimator save & load
+        with tempfile.TemporaryDirectory(prefix="ALS") as d:
+            als.write().overwrite().save(d)
+            als2 = ALS.load(d)
+            self.assertEqual(str(als), str(als2))
+
+        model = als.fit(df)
+        self.assertEqual(model.rank, 10)
+
+        self.assertEqual(model.itemFactors.columns, ["id", "features"])
+        self.assertEqual(model.itemFactors.count(), 3)
+
+        self.assertEqual(model.userFactors.columns, ["id", "features"])
+        self.assertEqual(model.userFactors.count(), 2)
+
+        # Transform
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["user", "item", "rating", "prediction"])
+        self.assertEqual(output.count(), 4)
+
+        output1 = model.recommendForAllUsers(3)
+        self.assertEqual(output1.columns, ["user", "recommendations"])
+        self.assertEqual(output1.count(), 2)
+
+        output2 = model.recommendForAllItems(3)
+        self.assertEqual(output2.columns, ["item", "recommendations"])
+        self.assertEqual(output2.count(), 3)
+
+        output3 = model.recommendForUserSubset(df, 3)
+        self.assertEqual(output3.columns, ["user", "recommendations"])
+        self.assertEqual(output3.count(), 2)
+
+        output4 = model.recommendForItemSubset(df, 3)
+        self.assertEqual(output4.columns, ["item", "recommendations"])
+        self.assertEqual(output4.count(), 3)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="als_model") as d:
+            model.write().overwrite().save(d)
+            model2 = ALSModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
     def test_ambiguous_column(self):
         data = self.spark.createDataFrame(
             [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
@@ -56,6 +122,10 @@ class ALSTest(ReusedSQLTestCase):
                 self.assertTrue(predictions.count() > 0)
 
 
+class ALSTests(ALSTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests.test_als import *  # noqa: F401
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e220e69a62c5..000a01c232bd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -448,7 +448,14 @@ private[ml] object MLUtils {
     "clusterSizes", // KMeansSummary
     "trainingCost", // KMeansSummary
     "cluster", // KMeansSummary
-    "computeCost" // BisectingKMeansModel
+    "computeCost", // BisectingKMeansModel
+    "rank", // ALSModel
+    "itemFactors", // ALSModel
+    "userFactors", // ALSModel
+    "recommendForAllUsers", // ALSModel
+    "recommendForAllItems", // ALSModel
+    "recommendForUserSubset", // ALSModel
+    "recommendForItemSubset" // ALSModel
   )
 
   def invokeMethodAllowed(obj: Object, methodName: String): Object = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to