This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cdd52963281a [SPARK-51856][ML][CONNECT] Update model size API to count
distributed DataFrame size
cdd52963281a is described below
commit cdd52963281abb62792ba51491a98fa9f87f968a
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Apr 23 08:21:06 2025 +0800
[SPARK-51856][ML][CONNECT] Update model size API to count distributed
DataFrame size
### What changes were proposed in this pull request?
Update the model size API so that it also counts the size of distributed DataFrames, in addition to driver-side memory usage.
### Why are the changes needed?
The model size estimate is used for ML cache management on the Spark server.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50652 from WeichenXu123/get-model-ser-size-api.
Lead-authored-by: Weichen Xu <[email protected]>
Co-authored-by: WeichenXu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../src/main/scala/org/apache/spark/ml/Estimator.scala | 4 ++--
mllib/src/main/scala/org/apache/spark/ml/Model.scala | 4 ++--
.../scala/org/apache/spark/ml/clustering/LDA.scala | 5 +++++
.../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 5 +++++
.../scala/org/apache/spark/ml/recommendation/ALS.scala | 12 ++++++++++++
.../org/apache/spark/ml/recommendation/ALSSuite.scala | 18 ++++++++++++++++++
python/pyspark/ml/tests/test_clustering.py | 3 +++
python/pyspark/ml/tests/test_fpm.py | 4 +++-
8 files changed, 50 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 686afc115436..ead68b290fe4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -87,8 +87,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage
{
* Estimate an upper-bound size of the model to be fitted in bytes, based on
the
* parameters and the dataset, e.g., using $(k) and numFeatures to estimate a
* k-means model size.
- * 1, Only driver side memory usage is counted, distributed objects (like
DataFrame,
- * RDD, Graph, Summary) are ignored.
+ * 1, Both driver side memory usage and distributed objects size (like
DataFrame,
+ * RDD, Graph, Summary) are counted.
* 2, Lazy vals are not counted, e.g., an auxiliary object used in
prediction.
* 3, If there is no enough information to get an accurate size, try to
estimate the
* upper-bound size, e.g.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 7e0297515fa2..6321e5f88f74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -49,8 +49,8 @@ abstract class Model[M <: Model[M]] extends Transformer {
self =>
* For ml connect only.
* Estimate the size of this model in bytes.
* This is an approximation, the real size might be different.
- * 1, Only driver side memory usage is counted, distributed objects (like
DataFrame,
- * RDD, Graph, Summary) are ignored.
+ * 1, Both driver side memory usage and distributed objects size (like
DataFrame,
+ * RDD, Graph, Summary) are counted.
* 2, Lazy vals are not counted, e.g., an auxiliary object used in
prediction.
* 3, The default implementation uses
`org.apache.spark.util.SizeEstimator.estimate`,
* some models override the default implementation to achieve more
precise estimation.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 3ea1c8594e1f..0c5211864385 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -805,6 +805,11 @@ class DistributedLDAModel private[ml] (
override def toString: String = {
s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
}
+
+ override def estimatedSize: Long = {
+ // TODO: Implement this method.
+ throw new UnsupportedOperationException
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 0b75753695fd..7a932d250cee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -322,6 +322,11 @@ class FPGrowthModel private[ml] (
override def toString: String = {
s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
}
+
+ override def estimatedSize: Long = {
+ // TODO: Implement this method.
+ throw new UnsupportedOperationException
+ }
}
@Since("2.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 95c47531720d..36255d3df0f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -540,6 +540,11 @@ class ALSModel private[ml] (
}
}
+ override def estimatedSize: Long = {
+ val userCount = userFactors.count()
+ val itemCount = itemFactors.count()
+ (userCount + itemCount) * (rank + 1) * 4
+ }
}
@Since("1.6.0")
@@ -771,6 +776,13 @@ class ALS(@Since("1.4.0") override val uid: String)
extends Estimator[ALSModel]
@Since("1.5.0")
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
+
+ override def estimateModelSize(dataset: Dataset[_]): Long = {
+ val userCount = dataset.select(getUserCol).distinct().count()
+ val itemCount = dataset.select(getItemCol).distinct().count()
+ val rank = getRank
+ (userCount + itemCount) * (rank + 1) * 4
+ }
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 94abeaf0804e..4da67a92d707 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -1128,6 +1128,24 @@ class ALSStorageSuite extends SparkFunSuite with
MLlibTestSparkContext with Defa
levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
nonDefaultListener.storageLevels.foreach(level => assert(level ==
StorageLevel.DISK_ONLY))
}
+
+ test("saved model size estimation") {
+ import testImplicits._
+
+ val als = new ALS().setMaxIter(1).setRank(8)
+ val estimatedDFSize = (3 + 2) * (8 + 1) * 4
+ val df = sc.parallelize(Seq(
+ (123, 1, 0.5),
+ (123, 2, 0.7),
+ (123, 3, 0.6),
+ (111, 2, 1.0),
+ (111, 1, 0.1)
+ )).toDF("item", "user", "rating")
+ assert(als.estimateModelSize(df) === estimatedDFSize)
+
+ val model = als.fit(df)
+ assert(model.estimatedSize == estimatedDFSize)
+ }
}
private class IntermediateRDDStorageListener extends SparkListener {
diff --git a/python/pyspark/ml/tests/test_clustering.py
b/python/pyspark/ml/tests/test_clustering.py
index a35eaac10a7e..1b8eb73135a9 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -37,6 +37,7 @@ from pyspark.ml.clustering import (
DistributedLDAModel,
PowerIterationClustering,
)
+from pyspark.sql import is_remote
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -377,6 +378,8 @@ class ClusteringTestsMixin:
self.assertEqual(str(model), str(model2))
def test_distributed_lda(self):
+ if is_remote():
+ self.skipTest("Do not support Spark Connect.")
spark = self.spark
df = (
spark.createDataFrame(
diff --git a/python/pyspark/ml/tests/test_fpm.py
b/python/pyspark/ml/tests/test_fpm.py
index ea94216c9860..7b949763c398 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
import tempfile
import unittest
-from pyspark.sql import Row
+from pyspark.sql import is_remote, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
@@ -30,6 +30,8 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase
class FPMTestsMixin:
def test_fp_growth(self):
+ if is_remote():
+ self.skipTest("Do not support Spark Connect.")
df = self.spark.createDataFrame(
[
["r z h k p"],
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]