Dong Wang created SPARK-29815:
---------------------------------

             Summary: Missing persist in ml.tuning.CrossValidator.fit()
                 Key: SPARK-29815
                 URL: https://issues.apache.org/jira/browse/SPARK-29815
             Project: Spark
          Issue Type: Improvement
          Components: ML
    Affects Versions: 2.4.3
            Reporter: Dong Wang


In ml.tuning.CrossValidator.fit(dataset: Dataset[_]), MLUtils.kFold splits
dataset.toDF.rdd into $(numFolds) pairs of (training, validation) RDDs. Actions
are later run on each of these RDDs, but dataset.toDF.rdd itself is not
persisted, so its lineage is recomputed every time one of the derived RDDs is
computed.

{code:scala}
    // Compute metrics for each model over each split
    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) // dataset.toDF.rdd should be persisted
    val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
      // ... (rest of the loop omitted)
{code}
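To illustrate the recomputation, here is a Spark-free toy model, not Spark code. LazySource and KFoldToy are hypothetical names. The model mimics only the shape of MLUtils.kFold (numFolds training/validation pairs over one shared parent), not its Bernoulli sampling. It also assumes one action per side per fold, matching the .cache() on each derived DataFrame in the excerpt. Without persisting the parent, the upstream work runs 2 * numFolds times; with it, the work runs once. In Spark the analogous change would presumably be calling persist() on dataset.toDF.rdd before kFold and unpersist() after the loop.

{code:scala}
// Toy, Spark-free model of a lazily computed parent shared by k folds.
// LazySource and KFoldToy are hypothetical; they are NOT Spark APIs.
final class LazySource[T](compute: () => Seq[T]) {
  var computeCount = 0                     // times the upstream work actually ran
  private var cached: Option[Seq[T]] = None
  private var persisted = false

  def persist(): this.type = { persisted = true; this }
  def unpersist(): this.type = { persisted = false; cached = None; this }

  // Called by every "action" on a fold derived from this source.
  def materialize(): Seq[T] = cached match {
    case Some(rows) => rows                // persisted: reuse the stored result
    case None =>
      computeCount += 1                    // not cached: recompute the lineage
      val rows = compute()
      if (persisted) cached = Some(rows)
      rows
  }
}

object KFoldToy {
  // Mirrors only the shape of MLUtils.kFold (numFolds (training, validation) pairs),
  // using index-modulo splits instead of Spark's Bernoulli sampling.
  def kFold[T](src: LazySource[T], numFolds: Int): Seq[(() => Seq[T], () => Seq[T])] =
    (0 until numFolds).map { fold =>
      val training: () => Seq[T] =
        () => src.materialize().zipWithIndex.collect { case (r, i) if i % numFolds != fold => r }
      val validation: () => Seq[T] =
        () => src.materialize().zipWithIndex.collect { case (r, i) if i % numFolds == fold => r }
      (training, validation)
    }

  // Runs one action on each side of every fold; returns how often the source was computed.
  def runFolds[T](src: LazySource[T], numFolds: Int): Int = {
    kFold(src, numFolds).foreach { case (training, validation) =>
      training()    // stands in for fitting a model on the training split
      validation()  // stands in for evaluating that model on the validation split
    }
    src.computeCount
  }

  def main(args: Array[String]): Unit = {
    val unpersisted = new LazySource(() => (1 to 100).toSeq)
    println(s"unpersisted: ${runFolds(unpersisted, 3)}") // unpersisted: 6

    val persisted = new LazySource(() => (1 to 100).toSeq).persist()
    println(s"persisted: ${runFolds(persisted, 3)}")     // persisted: 1
    persisted.unpersist()
  }
}
{code}

The counter only makes the cost visible. In a real job, each extra count is a full re-read and re-transformation of the input dataset.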

This issue was reported by our tool CacheCheck, which dynamically detects
misuses of the persist()/unpersist() APIs.



--
This message was sent by Atlassian Jira
(v8.3.4#803005)
