Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/19904#discussion_r156751569
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -146,25 +147,18 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+      val completeFitCount = new AtomicInteger(0)
--- End diff ---
This might work, but I think what you have in the PR now is better.
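(For readers without the diff expanded: the PR gates the unpersist on a completed-fit counter. A rough sketch of that pattern, not the PR's exact code, assuming `epm.length` models are fit:)

```
// Sketch of the counter-based cleanup implied by completeFitCount above
val completeFitCount = new AtomicInteger(0)
val modelFutures = epm.map { paramMap =>
  Future[Model[_]] {
    val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
    // The last fit to finish unpersists the shared training data
    if (completeFitCount.incrementAndGet() == epm.length) {
      trainingDataset.unpersist()
    }
    model
  } (executionContext)
}
```

The Future-based alternative I had in mind was: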
```
// Fit models in a Future for training in parallel
var modelFutures = epm.map { paramMap =>
  Future[Model[_]] {
    val model = est.fit(trainingDataset, paramMap)
    model.asInstanceOf[Model[_]]
  } (executionContext)
}

// Unpersist training data only when all models have trained
val unitFutures = modelFutures.map { _.map { _ => () } (executionContext) }
Future.sequence[Unit, Iterable](unitFutures)(implicitly, executionContext)
  .onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calculate a metric and allow the model to be cleaned up
val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
  modelFuture.map { model =>
    // TODO: duplicate evaluator to take extra params from input
    val metric = eval.evaluate(model.transform(validationDataset, paramMap))
    logDebug(s"Got metric $metric for model trained with $paramMap.")
    metric
  } (executionContext)
}
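// Drop the reference to the model futures so each model can be
// garbage collected once its metric has been computed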
modelFutures = null
```
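
For anyone who wants to experiment with the unpersist-after-all-futures pattern outside Spark, here is a minimal, self-contained sketch. Everything in it is a placeholder: `trainOne` stands in for `est.fit`, `params` for `epm`, the `println` in `onComplete` for `trainingDataset.unpersist()`, and a fixed thread pool for the real `executionContext`.

```
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object CleanupAfterAll {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(4)
    implicit val executionContext: ExecutionContext =
      ExecutionContext.fromExecutorService(pool)

    // Placeholder for est.fit(trainingDataset, paramMap): any expensive job
    def trainOne(param: Int): String = { Thread.sleep(100); s"model-$param" }
    val params = Seq(1, 2, 3)

    // Kick off the "training" jobs in parallel
    val modelFutures = params.map(p => Future { trainOne(p) })

    // Map each future to Unit and sequence them so the cleanup callback
    // fires exactly once, after every training job has finished
    Future.sequence(modelFutures.map(_.map(_ => ())))
      .onComplete { _ => println("all trained; unpersist training data here") }

    // Chain further work (evaluation) off each model future independently
    val metricFutures = modelFutures.map(_.map(model => s"metric for $model"))
    println(Await.result(Future.sequence(metricFutures), Duration.Inf))

    pool.shutdown()
  }
}
```

The point of mapping to `Unit` before sequencing is that the cleanup future never holds references to the trained models themselves, so it does not keep them alive any longer than the evaluation futures do.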
---