Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/16774#discussion_r137052323
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -100,31 +113,53 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
    val eval = $(evaluator)
    val epm = $(estimatorParamMaps)
    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+
+    // Create execution context based on $(parallelism)
+    val executionContext = getExecutionContext
    val instr = Instrumentation.create(this, dataset)
-    instr.logParams(numFolds, seed)
+    instr.logParams(numFolds, seed, parallelism)
    logTuningParams(instr)
+    // Compute metrics for each model over each split
    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
-    splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+    val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
-      // multi-model training
      logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-      trainingDataset.unpersist()
-      var i = 0
-      while (i < numModels) {
-        // TODO: duplicate evaluator to take extra params from input
-        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-        metrics(i) += metric
-        i += 1
+
+      // Fit models in a Future for training in parallel
+      val models = epm.map { paramMap =>
+        Future[Model[_]] {
+          val model = est.fit(trainingDataset, paramMap)
+          model.asInstanceOf[Model[_]]
+        } (executionContext)
      }
+
+      // Unpersist training data only when all models have trained
+      Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
+        trainingDataset.unpersist()
+      } (executionContext)
+
+      // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
+      val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
+        modelFuture.map { model =>
+          // TODO: duplicate evaluator to take extra params from input
+          val metric = eval.evaluate(model.transform(validationDataset, paramMap))
+          logDebug(s"Got metric $metric for model trained with $paramMap.")
+          metric
+        } (executionContext)
+      }
+
+      // Wait for metrics to be calculated before unpersisting validation dataset
+      val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
      validationDataset.unpersist()
-    }
+      foldMetrics
+    }.transpose.map(_.sum)
+
+    // Calculate average metric over all splits
    f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
--- End diff --
Yeah, I agree. I'll go ahead and make that change here.
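
For anyone following the diff, here is a minimal, self-contained sketch (plain Scala, not the actual Spark code) of the Future-based pattern used above: each fit runs as a Future on a bounded execution context, the shared training data is released only once every fit has completed, and the per-model metrics are awaited before continuing. Everything in it (`ParallelFitSketch`, `fitOne`, the placeholder strings) is hypothetical and only illustrates the pattern; the real code uses `est.fit`, `getExecutionContext`, and `ThreadUtils.awaitResult`.

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
import scala.concurrent.duration.Duration

object ParallelFitSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for getExecutionContext: a thread pool sized by a "parallelism" setting.
    val parallelism = 2
    implicit val executionContext: ExecutionContextExecutorService =
      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(parallelism))

    val paramMaps = Seq("lr=0.1", "lr=0.01", "lr=0.001")   // stand-in for epm
    val trainingData = "cached training data"              // stand-in for trainingDataset

    // Placeholder "fit"; in the real code this is est.fit(trainingDataset, paramMap).
    def fitOne(params: String): String = s"model($params)"

    // Kick off one Future per param map so the fits run concurrently.
    val models: Seq[Future[String]] = paramMaps.map { params =>
      Future { fitOne(params) }
    }

    // Release the shared training data only after every fit has finished,
    // mirroring Future.sequence(...).onComplete { trainingDataset.unpersist() } above.
    Future.sequence(models).onComplete { _ =>
      println(s"done with: $trainingData")
    }

    // Map each model Future to its metric, then block for all results
    // (the diff uses ThreadUtils.awaitResult instead of Await.result).
    val metricFutures = models.map(_.map(model => model.length.toDouble))
    val metrics = metricFutures.map(Await.result(_, Duration.Inf))
    println(s"metrics: $metrics")

    executionContext.shutdown()
  }
}
```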
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]