This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new eb97f95 [SPARK-28154][ML][FOLLOWUP] GMM fix double caching
eb97f95 is described below
commit eb97f952f6c54af4f192c7bf85a09612000ff793
Author: zhengruifeng <[email protected]>
AuthorDate: Tue Jun 25 06:50:34 2019 -0500
[SPARK-28154][ML][FOLLOWUP] GMM fix double caching
If the input dataset is already cached, then we do not need to cache the
internal RDD (as is done in KMeans).
Tested with the existing tests.
Closes #24919 from zhengruifeng/gmm_fix_double_caching.
Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
(cherry picked from commit c83b3ddb56d4a32158676b042a7eae861689e141)
Signed-off-by: Sean Owen <[email protected]>
---
.../org/apache/spark/ml/clustering/GaussianMixture.scala | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88abc16..6c0b49c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
/**
@@ -342,10 +343,15 @@ class GaussianMixture @Since("2.0.0") (
val sc = dataset.sparkSession.sparkContext
val numClusters = $(k)
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
case Row(features: Vector) => features
- }.cache()
+ }
+
+ if (handlePersistence) {
+ instances.persist(StorageLevel.MEMORY_AND_DISK)
+ }
// Extract the number of features.
val numFeatures = instances.first().size
@@ -422,8 +428,10 @@ class GaussianMixture @Since("2.0.0") (
logLikelihood = sums.logLikelihood // this is the freshly computed
log-likelihood
iter += 1
}
+ if (handlePersistence) {
+ instances.unpersist()
+ }
- instances.unpersist(false)
val gaussianDists = gaussians.map { case (mean, covVec) =>
val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures,
covVec.values)
new MultivariateGaussian(mean, cov)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]