GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20095#discussion_r159127966
--- Diff: mllib/src/main/scala/org/apache/spark/ml/Estimator.scala ---
@@ -79,7 +82,52 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
*/
@Since("2.0.0")
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
- paramMaps.map(fit(dataset, _))
+ val modelIter = fitMultiple(dataset, paramMaps)
+ val models = paramMaps.map { _ =>
+ val (index, model) = modelIter.next()
+ (index.toInt, model)
+ }
+ paramMaps.indices.map(models.toMap)
+}
--- End diff --
Style: the indentation in this block does not follow the project's Scala style; please fix it.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]