This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new eb97f95  [SPARK-28154][ML][FOLLOWUP] GMM fix double caching
eb97f95 is described below

commit eb97f952f6c54af4f192c7bf85a09612000ff793
Author: zhengruifeng <[email protected]>
AuthorDate: Tue Jun 25 06:50:34 2019 -0500

    [SPARK-28154][ML][FOLLOWUP] GMM fix double caching
    
    If the input dataset is already cached, then we do not need to cache the
    internal RDD (as is done in KMeans).
    
    Tested with the existing tests.
    
    Closes #24919 from zhengruifeng/gmm_fix_double_caching.
    
    Authored-by: zhengruifeng <[email protected]>
    Signed-off-by: Sean Owen <[email protected]>
    (cherry picked from commit c83b3ddb56d4a32158676b042a7eae861689e141)
    Signed-off-by: Sean Owen <[email protected]>
---
 .../org/apache/spark/ml/clustering/GaussianMixture.scala     | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88abc16..6c0b49c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
 
 
 /**
@@ -342,10 +343,15 @@ class GaussianMixture @Since("2.0.0") (
     val sc = dataset.sparkSession.sparkContext
     val numClusters = $(k)
 
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val instances = dataset
       .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
       case Row(features: Vector) => features
-    }.cache()
+    }
+
+    if (handlePersistence) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
 
     // Extract the number of features.
     val numFeatures = instances.first().size
@@ -422,8 +428,10 @@ class GaussianMixture @Since("2.0.0") (
       logLikelihood = sums.logLikelihood  // this is the freshly computed 
log-likelihood
       iter += 1
     }
+    if (handlePersistence) {
+      instances.unpersist()
+    }
 
-    instances.unpersist(false)
     val gaussianDists = gaussians.map { case (mean, covVec) =>
       val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, 
covVec.values)
       new MultivariateGaussian(mean, cov)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to