Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20095#discussion_r158931079
--- Diff: mllib/src/main/scala/org/apache/spark/ml/Estimator.scala ---
@@ -79,7 +82,51 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
*/
@Since("2.0.0")
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
- paramMaps.map(fit(dataset, _))
+ val modelIter = fitMultiple(dataset, paramMaps)
+ val models = paramMaps.map{ _ =>
+ val (index, model) = modelIter.next()
+ (index.toInt, model)
+ }
+ paramMaps.indices.map(models.toMap)
+}
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters. The default
+ * implementation calls `fit` once for each call to the iterator's `next` method. Subclasses
+ * could override this to optimize multi-model training.
+ *
+ * @param dataset input dataset
+ * @param paramMaps An array of parameter maps.
+ * These values override any specified in this Estimator's embedded ParamMap.
+ * @return An iterator which produces one model per call to `next`. The models may be produced in
+ * a different order than the order of the parameters in paramMap. The next method of
+ * the iterator will return a tuple of the form `(index, model)` where model corresponds
+ * to `paramMaps(index)`. This Iterator should be thread safe, meaning concurrent calls
+ * to `next` should always produce unique values of `index`.
+ *
+ * :: Experimental ::
+ */
+ @Experimental
+ @Since("2.3.0")
+ def fitMultiple(
+ dataset: Dataset[_],
+ paramMaps: Array[ParamMap]): JIterator[(Integer, M)] = {
+
+ val numModel = paramMaps.length
+ val counter = new AtomicInteger(0)
+ new JIterator[(Integer, M)] {
+ def next(): (Integer, M) = {
+ val index = counter.getAndIncrement()
+ if (index < numModel) {
+ (index, fit(dataset, paramMaps(index)))
+ } else {
+ counter.set(numModel)
--- End diff --
This `counter.set` seems to be meaningless.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]