This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7db7f3cbc6c3 [SPARK-54574][ML][CONNECT] Reenable FPGrowth on connect
7db7f3cbc6c3 is described below

commit 7db7f3cbc6c31c496160b35ecf2324f82198e7f1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 3 09:34:34 2025 +0800

    [SPARK-54574][ML][CONNECT] Reenable FPGrowth on connect
    
    ### What changes were proposed in this pull request?
    Reenable FPGrowth on Connect
    
    ### Why are the changes needed?
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, FPGrowth will be available on Connect.
    
    ### How was this patch tested?
    Updated tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53294 from zhengruifeng/fpgrowth_model_size.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala   | 11 ++++++++---
 python/pyspark/ml/tests/test_fpm.py                           |  4 +---
 .../scala/org/apache/spark/sql/connect/ml/MLHandler.scala     |  5 -----
 .../src/main/scala/org/apache/spark/sql/classic/Dataset.scala |  2 +-
 4 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index e25fdc3e05ab..e270294ef2be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -36,7 +36,7 @@ import 
org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.VersionUtils
+import org.apache.spark.util.{SizeEstimator, VersionUtils}
 
 /**
  * Common params for FPGrowth and FPGrowthModel
@@ -324,8 +324,13 @@ class FPGrowthModel private[ml] (
   }
 
   override def estimatedSize: Long = {
-    // TODO: Implement this method.
-    throw new UnsupportedOperationException
+    freqItemsets match {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _) +
+          SizeEstimator.estimate(itemSupport)
+      case o => throw new UnsupportedOperationException(
+        s"Unsupported dataframe type: ${o.getClass.getName}")
+    }
   }
 }
 
diff --git a/python/pyspark/ml/tests/test_fpm.py 
b/python/pyspark/ml/tests/test_fpm.py
index 7b949763c398..ea94216c9860 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
 import tempfile
 import unittest
 
-from pyspark.sql import is_remote, Row
+from pyspark.sql import Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
@@ -30,8 +30,6 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 class FPMTestsMixin:
     def test_fp_growth(self):
-        if is_remote():
-            self.skipTest("Do not support Spark Connect.")
         df = self.spark.createDataFrame(
             [
                 ["r z h k p"],
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 40f1172677a5..3a53aa77fde6 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,11 +229,6 @@ private[connect] object MLHandler extends Logging {
           } catch {
             case _: UnsupportedOperationException => ()
           }
-          if (estimator.getClass.getName == 
"org.apache.spark.ml.fpm.FPGrowth") {
-            throw MlUnsupportedException(
-              "FPGrowth algorithm is not supported " +
-                "if Spark Connect model cache offloading is enabled.")
-          }
           if (estimator.getClass.getName == 
"org.apache.spark.ml.clustering.LDA"
             && estimator
               .asInstanceOf[org.apache.spark.ml.clustering.LDA]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d73918586b09..d02b63b49ca5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2378,7 +2378,7 @@ class Dataset[T] private[sql](
       sparkSession.sessionState.conf.arrowUseLargeVarTypes)
   }
 
-  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+  private[spark] def toArrowBatchRdd: RDD[Array[Byte]] = {
     toArrowBatchRdd(queryExecution.executedPlan)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to