Github user MLnick commented on a diff in the pull request:
https://github.com/apache/spark/pull/16774#discussion_r110614175
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---
@@ -87,37 +90,64 @@ class TrainValidationSplit @Since("1.5.0")
(@Since("1.5.0") override val uid: St
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /** @group setParam */
+ @Since("2.2.0")
+ def setNumParallelEval(value: Int): this.type = set(numParallelEval,
value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
- val numModels = epm.length
- val metrics = new Array[Double](epm.length)
+
+ // Create execution context, run in serial if numParallelEval is 1
+ val executionContext = $(numParallelEval) match {
+ case 1 =>
+ ThreadUtils.sameThread
+ case n =>
+ ExecutionContext.fromExecutorService(executorServiceFactory(n))
+ }
val instr = Instrumentation.create(this, dataset)
instr.logParams(trainRatio, seed)
logTuningParams(instr)
+ logDebug(s"Running validation with level of parallelism:
$numParallelEval.")
val Array(trainingDataset, validationDataset) =
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
trainingDataset.cache()
validationDataset.cache()
- // multi-model training
+ // Fit models in a Future with thread-pool size determined by
'$numParallelEval'
logDebug(s"Train split with multiple sets of parameters.")
- val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
- trainingDataset.unpersist()
- var i = 0
- while (i < numModels) {
- // TODO: duplicate evaluator to take extra params from input
- val metric = eval.evaluate(models(i).transform(validationDataset,
epm(i)))
- logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
- metrics(i) += metric
- i += 1
+ val models = epm.map { paramMap =>
+ Future[Model[_]] {
+ val model = est.fit(trainingDataset, paramMap)
+ model.asInstanceOf[Model[_]]
+ } (executionContext)
}
+
+ Future.sequence[Model[_], Iterable](models)(implicitly,
executionContext).onComplete { _ =>
+ trainingDataset.unpersist()
+ } (executionContext)
+
+ // Evaluate models concurrently, limited by a barrier with
'$numParallelEval' permits
--- End diff --
This comment is stale.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]