Github user MrBago commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19904#discussion_r155693144
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
    @@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
           val validationDataset = sparkSession.createDataFrame(validation, 
schema).cache()
           logDebug(s"Train split $splitIndex with multiple sets of 
parameters.")
     
    +      var completeFitCount = 0
    +      val signal = new Object
           // Fit models in a Future for training in parallel
    -      val modelFutures = epm.zipWithIndex.map { case (paramMap, 
paramIndex) =>
    -        Future[Model[_]] {
    +      val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, 
paramIndex) =>
    +        Future[Double] {
               val model = est.fit(trainingDataset, 
paramMap).asInstanceOf[Model[_]]
    +          signal.synchronized {
    +            completeFitCount += 1
    +            signal.notify()
    +          }
     
               if (collectSubModelsParam) {
                 subModels.get(splitIndex)(paramIndex) = model
               }
    -          model
    -        } (executionContext)
    -      }
    -
    -      // Unpersist training data only when all models have trained
    -      Future.sequence[Model[_], Iterable](modelFutures)(implicitly, 
executionContext)
    -        .onComplete { _ => trainingDataset.unpersist() } (executionContext)
    -
    -      // Evaluate models in a Future that will calulate a metric and allow 
model to be cleaned up
    -      val foldMetricFutures = modelFutures.zip(epm).map { case 
(modelFuture, paramMap) =>
    -        modelFuture.map { model =>
               // TODO: duplicate evaluator to take extra params from input
               val metric = eval.evaluate(model.transform(validationDataset, 
paramMap))
               logDebug(s"Got metric $metric for model trained with $paramMap.")
               metric
             } (executionContext)
           }
    +      Future {
    +        signal.synchronized {
    +          while (completeFitCount < epm.length) {
    --- End diff --
    
    Sorry, I'm not too familiar with Futures in Scala. Is it safe to create a 
blocking future like this? Since it loops until every fit has finished, it 
would hold one of the pool's threads the whole time, so do you risk starving 
the thread pool? Can we just use an if statement in the `synchronized` block 
above? Something like:
    
    ```
    completeFitCount += 1
    if (completeFitCount == epm.length) {
        trainingDataset.unpersist()
    }
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to