Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/20095#discussion_r159006471
--- Diff: mllib/src/main/scala/org/apache/spark/ml/Estimator.scala ---
@@ -79,7 +82,51 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    */
   @Since("2.0.0")
   def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
-    paramMaps.map(fit(dataset, _))
+    val modelIter = fitMultiple(dataset, paramMaps)
+    val models = paramMaps.map{ _ =>
+      val (index, model) = modelIter.next()
+      (index.toInt, model)
+    }
+    paramMaps.indices.map(models.toMap)
+  }
+
+  /**
+   * Fits multiple models to the input data with multiple sets of parameters. The default
+   * implementation calls `fit` once for each call to the iterator's `next` method. Subclasses
+   * could override this to optimize multi-model training.
+   *
+   * @param dataset input dataset
+   * @param paramMaps An array of parameter maps.
+   *                  These values override any specified in this Estimator's embedded ParamMap.
+   * @return An iterator which produces one model per call to `next`. The models may be produced in
+   *         a different order than the order of the parameters in paramMap. The next method of
+   *         the iterator will return a tuple of the form `(index, model)` where model corresponds
+   *         to `paramMaps(index)`. This Iterator should be thread safe, meaning concurrent calls
+   *         to `next` should always produce unique values of `index`.
+   *
+   * :: Experimental ::
+   */
+  @Experimental
+  @Since("2.3.0")
+  def fitMultiple(
+      dataset: Dataset[_],
+      paramMaps: Array[ParamMap]): JIterator[(Integer, M)] = {
+
+    val numModel = paramMaps.length
+    val counter = new AtomicInteger(0)
+    new JIterator[(Integer, M)] {
+      def next(): (Integer, M) = {
+        val index = counter.getAndIncrement()
+        if (index < numModel) {
+          (index, fit(dataset, paramMaps(index)))
+        } else {
+          counter.set(numModel)
+          throw new NoSuchElementException("Iterator finished.")
+        }
+      }
+
+      override def hasNext: Boolean = counter.get() < numModel
--- End diff --
This is true of any concurrent class with multiple methods; take `scala.collection.concurrent.Map`, for example. If I were to write `if (!myMap.contains(key)) { myMap += (key -> value) }`, I could not guarantee that `key` was not added between my calls to `contains` and `+=`.
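A minimal sketch of that check-then-act race, assuming `scala.collection.concurrent.TrieMap` as the concrete `concurrent.Map` (the object name, keys, and values are made up for illustration). The `contains`/`+=` pair can interleave with other threads, whereas the trait's `putIfAbsent` performs the check and the insert as one atomic operation.

```scala
import scala.collection.concurrent.TrieMap

object CheckThenAct {
  /** Returns (result of putIfAbsent, value finally stored under "a"). */
  def demo(): (Option[Int], Int) = {
    val myMap = TrieMap.empty[String, Int]

    // Not atomic: another thread could insert "a" after `contains`
    // returns false but before `+=` runs, and one of the two writes
    // would silently overwrite the other.
    if (!myMap.contains("a")) {
      myMap += ("a" -> 1)
    }

    // Atomic: putIfAbsent checks and inserts in a single step and
    // returns the value that was already present, if any.
    val previous = myMap.putIfAbsent("a", 2)
    (previous, myMap("a"))
  }

  def main(args: Array[String]): Unit =
    println(demo()) // (Some(1),1)
}
```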
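In the multithreaded case, folks will need to do something like this:
```
try {
  while (true) {
    val (index, model) = iter.next()
    // ... use model, which corresponds to paramMaps(index)
  }
} catch {
  case _: NoSuchElementException => // every model has been produced
}
```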
or simply count the number of calls to `next` and ensure that it equals the number of `paramMaps` passed to `fitMultiple` (that's mostly what I do in this PR).
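Both strategies can be sketched without Spark, assuming a stand-in for the default `fitMultiple` from the diff in which a `String` plays the role of the fitted model. `ConsumeFitMultiple`, `fakeFitMultiple`, `drainUntilExhausted`, and `callNextExactly` are hypothetical names. The first strategy lets every thread call `next` until it throws; the second issues exactly one `next` call per param map and then reorders the results by index, the way `fit(dataset, paramMaps)` does in the diff.

```scala
import java.util.{Iterator => JIterator}
import java.util.concurrent.{Callable, Executors}
import java.util.concurrent.atomic.{AtomicInteger, AtomicIntegerArray}

object ConsumeFitMultiple {
  // Mirrors the default fitMultiple; s"model-$index" is a hypothetical
  // stand-in for fit(dataset, paramMaps(index)).
  def fakeFitMultiple(numModel: Int): JIterator[(Integer, String)] = {
    val counter = new AtomicInteger(0)
    new JIterator[(Integer, String)] {
      override def next(): (Integer, String) = {
        val index = counter.getAndIncrement()
        if (index < numModel) {
          (Integer.valueOf(index), s"model-$index")
        } else {
          counter.set(numModel) // stop the counter from growing on repeated calls
          throw new NoSuchElementException("Iterator finished.")
        }
      }
      override def hasNext: Boolean = counter.get() < numModel
    }
  }

  // Strategy 1: each worker calls next() until it throws; hasNext is never used.
  // Returns how many times each index was produced (expected: once each).
  def drainUntilExhausted(numModel: Int, numThreads: Int): Seq[Int] = {
    val iter = fakeFitMultiple(numModel)
    val hits = new AtomicIntegerArray(numModel)
    val threads = (1 to numThreads).map { _ =>
      new Thread(new Runnable {
        override def run(): Unit =
          try {
            while (true) {
              val (index, _) = iter.next()
              hits.incrementAndGet(index.intValue)
            }
          } catch {
            case _: NoSuchElementException => () // all models produced
          }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    (0 until numModel).map(i => hits.get(i))
  }

  // Strategy 2: submit exactly numModel tasks, each calling next() once,
  // then put every model back at the position of its param map.
  def callNextExactly(numModel: Int, numThreads: Int): Seq[String] = {
    val iter = fakeFitMultiple(numModel)
    val pool = Executors.newFixedThreadPool(numThreads)
    try {
      val futures = (0 until numModel).map { _ =>
        pool.submit(new Callable[(Integer, String)] {
          override def call(): (Integer, String) = iter.next()
        })
      }
      val byIndex = futures.map(_.get()).map { case (i, m) => (i.intValue, m) }.toMap
      (0 until numModel).map(byIndex)
    } finally {
      pool.shutdown()
    }
  }

  def main(args: Array[String]): Unit = {
    println(drainUntilExhausted(10, 4).forall(_ == 1)) // true
    println(callNextExactly(3, 2).mkString(", "))      // model-0, model-1, model-2
  }
}
```

Neither strategy consults `hasNext`, so the window between checking for and claiming an index never matters: a thread that loses the race for the last index simply sees `NoSuchElementException`.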
We need to define `hasNext` to implement the Java Iterator interface, and it could be useful in the single-threaded case, but we could drop `hasNext` if using a Java Iterator doesn't feel important.
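If the Java Iterator were dropped, one hypothetical shape (not something in this PR) is a producer whose single method both checks for and claims the next index and returns `None` when everything has been claimed, so callers have no `hasNext`/`next` pair to race on. `ModelProducer`, `nextOption`, `sequential`, and `drainAll` are illustrative names, and `fitOne` stands in for `fit(dataset, paramMaps(index))`.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical hasNext-free alternative to JIterator[(Integer, M)].
trait ModelProducer[M] {
  /** Some((index, model)), or None once every index has been claimed. */
  def nextOption(): Option[(Int, M)]
}

object ModelProducer {
  // Same claiming logic as the default fitMultiple, without the exception.
  def sequential[M](numModel: Int)(fitOne: Int => M): ModelProducer[M] =
    new ModelProducer[M] {
      private val counter = new AtomicInteger(0)
      override def nextOption(): Option[(Int, M)] = {
        val index = counter.getAndIncrement()
        if (index < numModel) Some((index, fitOne(index)))
        else {
          counter.set(numModel) // keep the counter bounded after exhaustion
          None
        }
      }
    }

  // Single-threaded consumer: loop until the producer reports None.
  def drainAll[M](producer: ModelProducer[M]): Vector[(Int, M)] = {
    val out = Vector.newBuilder[(Int, M)]
    var current = producer.nextOption()
    while (current.isDefined) {
      out += current.get
      current = producer.nextOption()
    }
    out.result()
  }

  def main(args: Array[String]): Unit =
    drainAll(sequential(3)(i => s"model-$i")).foreach(println)
}
```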