This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7db7f3cbc6c3 [SPARK-54574][ML][CONNECT] Reenable FPGrowth on connect
7db7f3cbc6c3 is described below
commit 7db7f3cbc6c31c496160b35ecf2324f82198e7f1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 3 09:34:34 2025 +0800
[SPARK-54574][ML][CONNECT] Reenable FPGrowth on connect
### What changes were proposed in this pull request?
Reenable FPGrowth on Connect
### Why are the changes needed?
For feature parity with Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes, FPGrowth will be available on Connect.
### How was this patch tested?
Updated existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53294 from zhengruifeng/fpgrowth_model_size.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 11 ++++++++---
python/pyspark/ml/tests/test_fpm.py | 4 +---
.../scala/org/apache/spark/sql/connect/ml/MLHandler.scala | 5 -----
.../src/main/scala/org/apache/spark/sql/classic/Dataset.scala | 2 +-
4 files changed, 10 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index e25fdc3e05ab..e270294ef2be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -36,7 +36,7 @@ import
org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.VersionUtils
+import org.apache.spark.util.{SizeEstimator, VersionUtils}
/**
* Common params for FPGrowth and FPGrowthModel
@@ -324,8 +324,13 @@ class FPGrowthModel private[ml] (
}
override def estimatedSize: Long = {
- // TODO: Implement this method.
- throw new UnsupportedOperationException
+ freqItemsets match {
+ case df: org.apache.spark.sql.classic.DataFrame =>
+ df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _) +
+ SizeEstimator.estimate(itemSupport)
+ case o => throw new UnsupportedOperationException(
+ s"Unsupported dataframe type: ${o.getClass.getName}")
+ }
}
}
diff --git a/python/pyspark/ml/tests/test_fpm.py
b/python/pyspark/ml/tests/test_fpm.py
index 7b949763c398..ea94216c9860 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
import tempfile
import unittest
-from pyspark.sql import is_remote, Row
+from pyspark.sql import Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
@@ -30,8 +30,6 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase
class FPMTestsMixin:
def test_fp_growth(self):
- if is_remote():
- self.skipTest("Do not support Spark Connect.")
df = self.spark.createDataFrame(
[
["r z h k p"],
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 40f1172677a5..3a53aa77fde6 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,11 +229,6 @@ private[connect] object MLHandler extends Logging {
} catch {
case _: UnsupportedOperationException => ()
}
- if (estimator.getClass.getName ==
"org.apache.spark.ml.fpm.FPGrowth") {
- throw MlUnsupportedException(
- "FPGrowth algorithm is not supported " +
- "if Spark Connect model cache offloading is enabled.")
- }
if (estimator.getClass.getName ==
"org.apache.spark.ml.clustering.LDA"
&& estimator
.asInstanceOf[org.apache.spark.ml.clustering.LDA]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d73918586b09..d02b63b49ca5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2378,7 +2378,7 @@ class Dataset[T] private[sql](
sparkSession.sessionState.conf.arrowUseLargeVarTypes)
}
- private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+ private[spark] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]