Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16774#discussion_r137052323
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
    @@ -100,31 +113,53 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
         val eval = $(evaluator)
         val epm = $(estimatorParamMaps)
         val numModels = epm.length
    -    val metrics = new Array[Double](epm.length)
    +
    +    // Create execution context based on $(parallelism)
    +    val executionContext = getExecutionContext
     
         val instr = Instrumentation.create(this, dataset)
    -    instr.logParams(numFolds, seed)
    +    instr.logParams(numFolds, seed, parallelism)
         logTuningParams(instr)
     
    +    // Compute metrics for each model over each split
         val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
    -    splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
    +    val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
           val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
           val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
    -      // multi-model training
           logDebug(s"Train split $splitIndex with multiple sets of 
parameters.")
    -      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
    -      trainingDataset.unpersist()
    -      var i = 0
    -      while (i < numModels) {
    -        // TODO: duplicate evaluator to take extra params from input
    -        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
    -        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
    -        metrics(i) += metric
    -        i += 1
    +
    +      // Fit models in a Future for training in parallel
    +      val models = epm.map { paramMap =>
    +        Future[Model[_]] {
    +          val model = est.fit(trainingDataset, paramMap)
    +          model.asInstanceOf[Model[_]]
    +        } (executionContext)
           }
    +
    +      // Unpersist training data only when all models have trained
    +      Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
    +        trainingDataset.unpersist()
    +      } (executionContext)
    +
    +      // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
    +      val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
    +        modelFuture.map { model =>
    +          // TODO: duplicate evaluator to take extra params from input
    +          val metric = eval.evaluate(model.transform(validationDataset, paramMap))
    +          logDebug(s"Got metric $metric for model trained with $paramMap.")
    +          metric
    +        } (executionContext)
    +      }
    +
    +      // Wait for metrics to be calculated before unpersisting validation dataset
    +      val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
           validationDataset.unpersist()
    -    }
    +      foldMetrics
    +    }.transpose.map(_.sum)
    +
    +    // Calculate average metric over all splits
         f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
    --- End diff ---
    
    Yeah, I agree.  I'll go ahead and make that change here.
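
    For readers skimming the archive, the diff above follows the standard Scala `Future` idiom: one fit per param map is submitted to an execution context sized by `$(parallelism)`, cleanup is deferred until all futures complete, and the caller then blocks for the per-model metrics. Below is a minimal, self-contained sketch of that idiom; `ParamSetting`, `Model`, and `fitAndEvaluate` are hypothetical stand-ins for the Spark Estimator/Evaluator calls, not Spark APIs.

    // A minimal sketch (not Spark code) of the Future-based pattern in the diff:
    // fit one "model" per parameter setting on a bounded thread pool, release
    // shared resources once all fits finish, then block for the metrics.
    import java.util.concurrent.Executors
    import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
    import scala.concurrent.duration.Duration

    object ParallelFitSketch {
      final case class ParamSetting(regParam: Double)           // hypothetical param map
      final case class Model(regParam: Double, metric: Double)  // hypothetical fitted model

      // Stand-in for est.fit(...) followed by eval.evaluate(...)
      def fitAndEvaluate(p: ParamSetting): Model =
        Model(p.regParam, metric = 1.0 / (1.0 + p.regParam))

      def main(args: Array[String]): Unit = {
        val parallelism = 3
        implicit val executionContext: ExecutionContextExecutorService =
          ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(parallelism))

        val paramSettings = Seq(ParamSetting(0.01), ParamSetting(0.1), ParamSetting(1.0))

        // One Future per parameter setting; concurrency is bounded by the pool size
        val modelFutures = paramSettings.map(p => Future { fitAndEvaluate(p) })

        // Run cleanup (e.g. unpersisting a cached dataset) only after every fit completes
        Future.sequence(modelFutures).onComplete { _ =>
          println("All fits complete; shared training data could be released here.")
        }

        // Block until every metric is available, as the diff does via ThreadUtils.awaitResult
        val metrics = modelFutures.map(f => Await.result(f, Duration.Inf).metric)
        println(s"Metrics per setting: $metrics")

        executionContext.shutdown()
      }
    }

    In the diff itself the per-fold metric lists are then transposed and summed across folds, so the trailing `f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)` simply scales each summed metric by 1/numFolds to produce the average metric per param map.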


---
