This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new a747b01c2a11 [SPARK-50898][ML][PYTHON][CONNECT] Support `FPGrowth` on connect
a747b01c2a11 is described below
commit a747b01c2a11bda7def895b2ee075f2c655b9ebb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 21 16:03:13 2025 +0800
[SPARK-50898][ML][PYTHON][CONNECT] Support `FPGrowth` on connect
### What changes were proposed in this pull request?
Support `FPGrowth` on connect
### Why are the changes needed?
for feature parity
### Does this PR introduce _any_ user-facing change?
Yes, new algorithms supported on connect
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49579 from zhengruifeng/ml_connect_fpm.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 3ba74bf9b509e1cddbda6bb4849782e26fa840ed)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
.../services/org.apache.spark.ml.Estimator | 4 +
.../services/org.apache.spark.ml.Transformer | 3 +
.../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +
python/pyspark/ml/fpm.py | 4 +-
.../pyspark/ml/tests/connect/test_parity_fpm.py | 30 +++----
python/pyspark/ml/tests/test_fpm.py | 94 ++++++++++++++++++++++
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 4 +-
8 files changed, 124 insertions(+), 19 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index cacd4a83bbe4..5fd3f7377276 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -664,6 +664,7 @@ pyspark_ml = Module(
# unittests
"pyspark.ml.tests.test_algorithms",
"pyspark.ml.tests.test_als",
+ "pyspark.ml.tests.test_fpm",
"pyspark.ml.tests.test_base",
"pyspark.ml.tests.test_evaluation",
"pyspark.ml.tests.test_feature",
@@ -1119,6 +1120,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_parity_als",
+ "pyspark.ml.tests.connect.test_parity_fpm",
"pyspark.ml.tests.connect.test_parity_classification",
"pyspark.ml.tests.connect.test_parity_regression",
"pyspark.ml.tests.connect.test_parity_clustering",
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index a7d7d3da9df3..4046cca07dc0 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -39,3 +39,7 @@ org.apache.spark.ml.clustering.BisectingKMeans
# recommendation
org.apache.spark.ml.recommendation.ALS
+
+
+# fpm
+org.apache.spark.ml.fpm.FPGrowth
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 392115be98ba..7c10796f9a87 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -38,3 +38,6 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
# recommendation
org.apache.spark.ml.recommendation.ALSModel
+
+# fpm
+org.apache.spark.ml.fpm.FPGrowthModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index d054ea8ebdb4..d90124c62d54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -223,6 +223,8 @@ class FPGrowthModel private[ml] (
private val numTrainingRecords: Long)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
+ private[ml] def this() = this(Identifiable.randomUID("fpgrowth"), null, Map.empty, 0L)
+
/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 72fcfccf19e4..c068b5f26ba8 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -20,7 +20,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
from pyspark import keyword_only, since
from pyspark.sql import DataFrame
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable
+from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
@@ -126,6 +126,7 @@ class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["
@property
@since("2.2.0")
+ @try_remote_attribute_relation
def freqItemsets(self) -> DataFrame:
"""
DataFrame with two columns:
@@ -136,6 +137,7 @@ class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["
@property
@since("2.2.0")
+ @try_remote_attribute_relation
def associationRules(self) -> DataFrame:
"""
DataFrame with four columns:
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/python/pyspark/ml/tests/connect/test_parity_fpm.py
similarity index 50%
copy from
mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
copy to python/pyspark/ml/tests/connect/test_parity_fpm.py
index a7d7d3da9df3..85ceba87a2f5 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/python/pyspark/ml/tests/connect/test_parity_fpm.py
@@ -15,27 +15,23 @@
# limitations under the License.
#
-# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
-# So register the supported estimator here if you're trying to add a new one.
+import unittest
-# classification
-org.apache.spark.ml.classification.LogisticRegression
-org.apache.spark.ml.classification.DecisionTreeClassifier
-org.apache.spark.ml.classification.RandomForestClassifier
-org.apache.spark.ml.classification.GBTClassifier
+from pyspark.ml.tests.test_fpm import FPMTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
-# regression
-org.apache.spark.ml.regression.LinearRegression
-org.apache.spark.ml.regression.DecisionTreeRegressor
-org.apache.spark.ml.regression.RandomForestRegressor
-org.apache.spark.ml.regression.GBTRegressor
+class FPMParityTests(FPMTestsMixin, ReusedConnectTestCase):
+ pass
-# clustering
-org.apache.spark.ml.clustering.KMeans
-org.apache.spark.ml.clustering.BisectingKMeans
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_fpm import * # noqa: F401
+ try:
+ import xmlrunner # type: ignore[import]
-# recommendation
-org.apache.spark.ml.recommendation.ALS
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_fpm.py
b/python/pyspark/ml/tests/test_fpm.py
new file mode 100644
index 000000000000..8db35158978d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+from pyspark.sql import SparkSession
+import pyspark.sql.functions as sf
+from pyspark.ml.fpm import (
+ FPGrowth,
+ FPGrowthModel,
+)
+
+
+class FPMTestsMixin:
+ def test_fp_growth(self):
+ df = self.spark.createDataFrame(
+ [
+ ["r z h k p"],
+ ["z y x w v u t s"],
+ ["s x o n r"],
+ ["x z y m t s q e"],
+ ["z"],
+ ["x z y r q t p"],
+ ],
+ ["items"],
+ ).select(sf.split("items", " ").alias("items"))
+
+ fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
+ fp.setNumPartitions(1)
+ self.assertEqual(fp.getMinSupport(), 0.2)
+ self.assertEqual(fp.getMinConfidence(), 0.7)
+ self.assertEqual(fp.getNumPartitions(), 1)
+
+ # Estimator save & load
+ with tempfile.TemporaryDirectory(prefix="fp_growth") as d:
+ fp.write().overwrite().save(d)
+ fp2 = FPGrowth.load(d)
+ self.assertEqual(str(fp), str(fp2))
+
+ model = fp.fit(df)
+
+ self.assertEqual(model.freqItemsets.columns, ["items", "freq"])
+ self.assertEqual(model.freqItemsets.count(), 54)
+
+ self.assertEqual(
+ model.associationRules.columns,
+ ["antecedent", "consequent", "confidence", "lift", "support"],
+ )
+ self.assertEqual(model.associationRules.count(), 89)
+
+ output = model.transform(df)
+ self.assertEqual(output.columns, ["items", "prediction"])
+ self.assertEqual(output.count(), 6)
+
+ # Model save & load
+ with tempfile.TemporaryDirectory(prefix="fp_growth_model") as d:
+ model.write().overwrite().save(d)
+ model2 = FPGrowthModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
+
+class FPMTests(FPMTestsMixin, unittest.TestCase):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_fpm import * # noqa: F401,F403
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4e93aec47ef0..b85bc6771f8e 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -500,7 +500,9 @@ private[ml] object MLUtils {
"recommendForAllUsers", // ALSModel
"recommendForAllItems", // ALSModel
"recommendForUserSubset", // ALSModel
- "recommendForItemSubset" // ALSModel
+ "recommendForItemSubset", // ALSModel
+ "associationRules", // FPGrowthModel
+ "freqItemsets" // FPGrowthModel
)
def invokeMethodAllowed(obj: Object, methodName: String): Object = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]