Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/19904#discussion_r155693144
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0")
override val uid: String)
val validationDataset = sparkSession.createDataFrame(validation,
schema).cache()
logDebug(s"Train split $splitIndex with multiple sets of
parameters.")
+ var completeFitCount = 0
+ val signal = new Object
// Fit models in a Future for training in parallel
- val modelFutures = epm.zipWithIndex.map { case (paramMap,
paramIndex) =>
- Future[Model[_]] {
+ val foldMetricFutures = epm.zipWithIndex.map { case (paramMap,
paramIndex) =>
+ Future[Double] {
val model = est.fit(trainingDataset,
paramMap).asInstanceOf[Model[_]]
+ signal.synchronized {
+ completeFitCount += 1
+ signal.notify()
+ }
if (collectSubModelsParam) {
subModels.get(splitIndex)(paramIndex) = model
}
- model
- } (executionContext)
- }
-
- // Unpersist training data only when all models have trained
- Future.sequence[Model[_], Iterable](modelFutures)(implicitly,
executionContext)
- .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
- // Evaluate models in a Future that will calulate a metric and allow
model to be cleaned up
- val foldMetricFutures = modelFutures.zip(epm).map { case
(modelFuture, paramMap) =>
- modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset,
paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
+ Future {
+ signal.synchronized {
+ while (completeFitCount < epm.length) {
--- End diff --
Sorry, I'm not too familiar with Futures in Scala. Is it safe to create a
blocking future like this — do you risk starving the thread pool? Can we just
use an if statement in the `synchronized` block above? Something like:
```
completeFitCount += 1
if (completeFitCount == epm.length) {
trainingDataset.unpersist()
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]