This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 28bbb5481c59 [SPARK-50878][ML][PYTHON][CONNECT] Support ALS on Connect
28bbb5481c59 is described below
commit 28bbb5481c5986931a480f4e1efafa864648aed6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 20 20:52:39 2025 +0800
[SPARK-50878][ML][PYTHON][CONNECT] Support ALS on Connect
### What changes were proposed in this pull request?
Support ALS on Connect
### Why are the changes needed?
For feature parity
### Does this PR introduce _any_ user-facing change?
Yes, ALS is now supported on Spark Connect.
### How was this patch tested?
Added tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49570 from zhengruifeng/ml_connect_als.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 78e4b79a15953e3d825e29c4cf54cb5cd4c48a12)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
.../services/org.apache.spark.ml.Estimator | 4 ++
python/pyspark/ml/recommendation.py | 8 ++-
.../pyspark/ml/tests/connect/test_parity_als.py | 30 ++++-----
python/pyspark/ml/tests/test_als.py | 72 +++++++++++++++++++++-
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 9 ++-
6 files changed, 106 insertions(+), 18 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f34a33dd4b69..cacd4a83bbe4 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1118,6 +1118,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_connect_classification",
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
+ "pyspark.ml.tests.connect.test_parity_als",
"pyspark.ml.tests.connect.test_parity_classification",
"pyspark.ml.tests.connect.test_parity_regression",
"pyspark.ml.tests.connect.test_parity_clustering",
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 37b9c7e6aeb8..a7d7d3da9df3 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -35,3 +35,7 @@ org.apache.spark.ml.regression.GBTRegressor
# clustering
org.apache.spark.ml.clustering.KMeans
org.apache.spark.ml.clustering.BisectingKMeans
+
+
+# recommendation
+org.apache.spark.ml.recommendation.ALS
diff --git a/python/pyspark/ml/recommendation.py
b/python/pyspark/ml/recommendation.py
index 873140e51afb..d11990634593 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -30,7 +30,7 @@ from pyspark.ml.param.shared import (
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.common import inherit_doc
from pyspark.ml.param import Params, TypeConverters, Param
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable
+from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
from pyspark.sql import DataFrame
if TYPE_CHECKING:
@@ -617,6 +617,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
@property
@since("1.4.0")
+ @try_remote_attribute_relation
def userFactors(self) -> DataFrame:
"""
a DataFrame that stores user factors in two columns: `id` and
@@ -626,6 +627,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
@property
@since("1.4.0")
+ @try_remote_attribute_relation
def itemFactors(self) -> DataFrame:
"""
a DataFrame that stores item factors in two columns: `id` and
@@ -633,6 +635,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
"""
return self._call_java("itemFactors")
+ @try_remote_attribute_relation
def recommendForAllUsers(self, numItems: int) -> DataFrame:
"""
Returns top `numItems` items recommended for each user, for all users.
@@ -652,6 +655,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
"""
return self._call_java("recommendForAllUsers", numItems)
+ @try_remote_attribute_relation
def recommendForAllItems(self, numUsers: int) -> DataFrame:
"""
Returns top `numUsers` users recommended for each item, for all items.
@@ -671,6 +675,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
"""
return self._call_java("recommendForAllItems", numUsers)
+ @try_remote_attribute_relation
+    def recommendForUserSubset(self, dataset: DataFrame, numItems: int) -> DataFrame:
"""
Returns top `numItems` items recommended for each user id in the input
data set. Note that
@@ -694,6 +699,7 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable["ALSMo
"""
return self._call_java("recommendForUserSubset", dataset, numItems)
+ @try_remote_attribute_relation
+    def recommendForItemSubset(self, dataset: DataFrame, numUsers: int) -> DataFrame:
"""
Returns top `numUsers` users recommended for each item id in the input
data set. Note that
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/python/pyspark/ml/tests/connect/test_parity_als.py
similarity index 52%
copy from
mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
copy to python/pyspark/ml/tests/connect/test_parity_als.py
index 37b9c7e6aeb8..e9611900e550 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/python/pyspark/ml/tests/connect/test_parity_als.py
@@ -15,23 +15,23 @@
# limitations under the License.
#
-# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml
estimators.
-# So register the supported estimator here if you're trying to add a new one.
+import unittest
-# classification
-org.apache.spark.ml.classification.LogisticRegression
-org.apache.spark.ml.classification.DecisionTreeClassifier
-org.apache.spark.ml.classification.RandomForestClassifier
-org.apache.spark.ml.classification.GBTClassifier
+from pyspark.ml.tests.test_als import ALSTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
-# regression
-org.apache.spark.ml.regression.LinearRegression
-org.apache.spark.ml.regression.DecisionTreeRegressor
-org.apache.spark.ml.regression.RandomForestRegressor
-org.apache.spark.ml.regression.GBTRegressor
+class ALSParityTests(ALSTestsMixin, ReusedConnectTestCase):
+ pass
-# clustering
-org.apache.spark.ml.clustering.KMeans
-org.apache.spark.ml.clustering.BisectingKMeans
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_als import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_als.py
b/python/pyspark/ml/tests/test_als.py
index 3027b3ab9fd6..bd6cd1cb212f 100644
--- a/python/pyspark/ml/tests/test_als.py
+++ b/python/pyspark/ml/tests/test_als.py
@@ -23,7 +23,73 @@ from pyspark.ml.recommendation import ALS, ALSModel
from pyspark.testing.sqlutils import ReusedSQLTestCase
-class ALSTest(ReusedSQLTestCase):
+class ALSTestsMixin:
+ def test_als(self):
+ df = (
+ self.spark.createDataFrame(
+ [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
+ ["user", "item", "rating"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("user", "item")
+ )
+
+ als = ALS(
+ userCol="user",
+ itemCol="item",
+ rank=10,
+ seed=1,
+ )
+ als.setMaxIter(2)
+
+ self.assertEqual(als.getUserCol(), "user")
+ self.assertEqual(als.getItemCol(), "item")
+ self.assertEqual(als.getRank(), 10)
+ self.assertEqual(als.getSeed(), 1)
+ self.assertEqual(als.getMaxIter(), 2)
+
+ # Estimator save & load
+ with tempfile.TemporaryDirectory(prefix="ALS") as d:
+ als.write().overwrite().save(d)
+ als2 = ALS.load(d)
+ self.assertEqual(str(als), str(als2))
+
+ model = als.fit(df)
+ self.assertEqual(model.rank, 10)
+
+ self.assertEqual(model.itemFactors.columns, ["id", "features"])
+ self.assertEqual(model.itemFactors.count(), 3)
+
+ self.assertEqual(model.userFactors.columns, ["id", "features"])
+ self.assertEqual(model.userFactors.count(), 2)
+
+ # Transform
+ output = model.transform(df)
+        self.assertEqual(output.columns, ["user", "item", "rating", "prediction"])
+ self.assertEqual(output.count(), 4)
+
+ output1 = model.recommendForAllUsers(3)
+ self.assertEqual(output1.columns, ["user", "recommendations"])
+ self.assertEqual(output1.count(), 2)
+
+ output2 = model.recommendForAllItems(3)
+ self.assertEqual(output2.columns, ["item", "recommendations"])
+ self.assertEqual(output2.count(), 3)
+
+ output3 = model.recommendForUserSubset(df, 3)
+ self.assertEqual(output3.columns, ["user", "recommendations"])
+ self.assertEqual(output3.count(), 2)
+
+ output4 = model.recommendForItemSubset(df, 3)
+ self.assertEqual(output4.columns, ["item", "recommendations"])
+ self.assertEqual(output4.count(), 3)
+
+ # Model save & load
+ with tempfile.TemporaryDirectory(prefix="als_model") as d:
+ model.write().overwrite().save(d)
+ model2 = ALSModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
def test_ambiguous_column(self):
data = self.spark.createDataFrame(
[[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]],
@@ -56,6 +122,10 @@ class ALSTest(ReusedSQLTestCase):
self.assertTrue(predictions.count() > 0)
+class ALSTests(ALSTestsMixin, ReusedSQLTestCase):
+ pass
+
+
if __name__ == "__main__":
from pyspark.ml.tests.test_als import * # noqa: F401
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e220e69a62c5..000a01c232bd 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -448,7 +448,14 @@ private[ml] object MLUtils {
"clusterSizes", // KMeansSummary
"trainingCost", // KMeansSummary
"cluster", // KMeansSummary
- "computeCost" // BisectingKMeansModel
+ "computeCost", // BisectingKMeansModel
+ "rank", // ALSModel
+ "itemFactors", // ALSModel
+ "userFactors", // ALSModel
+ "recommendForAllUsers", // ALSModel
+ "recommendForAllItems", // ALSModel
+ "recommendForUserSubset", // ALSModel
+ "recommendForItemSubset" // ALSModel
)
def invokeMethodAllowed(obj: Object, methodName: String): Object = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]