GitHub user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20095#discussion_r158930992
--- Diff: mllib/src/main/scala/org/apache/spark/ml/Estimator.scala ---
@@ -79,7 +82,51 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
*/
@Since("2.0.0")
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
- paramMaps.map(fit(dataset, _))
+ val modelIter = fitMultiple(dataset, paramMaps)
+ val models = paramMaps.map{ _ =>
+ val (index, model) = modelIter.next()
+ (index.toInt, model)
+ }
+ paramMaps.indices.map(models.toMap)
+}
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters. The default
+ * implementation calls `fit` once for each call to the iterator's `next` method. Subclasses
+ * could override this to optimize multi-model training.
+ *
+ * @param dataset input dataset
+ * @param paramMaps An array of parameter maps.
+ * These values override any specified in this Estimator's embedded ParamMap.
+ * @return An iterator which produces one model per call to `next`. The models may be produced in
+ * a different order than the order of the parameters in paramMap. The next method of
+ * the iterator will return a tuple of the form `(index, model)` where model corresponds
+ * to `paramMaps(index)`. This Iterator should be thread safe, meaning concurrent calls
+ * to `next` should always produce unique values of `index`.
+ *
+ * :: Experimental ::
+ */
+ @Experimental
+ @Since("2.3.0")
+ def fitMultiple(
+ dataset: Dataset[_],
+ paramMaps: Array[ParamMap]): JIterator[(Integer, M)] = {
+
+ val numModel = paramMaps.length
+ val counter = new AtomicInteger(0)
+ new JIterator[(Integer, M)] {
+ def next(): (Integer, M) = {
+ val index = counter.getAndIncrement()
+ if (index < numModel) {
+ (index, fit(dataset, paramMaps(index)))
+ } else {
+ counter.set(numModel)
+ throw new NoSuchElementException("Iterator finished.")
+ }
+ }
+
+ override def hasNext: Boolean = counter.get() < numModel
--- End diff --
Suppose two threads share this iterator and only one element remains. Both threads call
`hasNext` first and both get `true`. Then both call `next`, and one of them throws
`NoSuchElementException`.
Is this the expected behavior? `hasNext` returns `true`, yet the following `next` can
still fail. That seems to break the `Iterator` API contract.
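To make the scenario concrete, here is a minimal Scala sketch. It assumes a Spark-free stand-in: the hypothetical `CounterIterator` copies the counter logic from the diff, but `fitOne(index)` replaces `fit(dataset, paramMaps(index))`. A `CountDownLatch` forces the interleaving described above, so that both threads see `hasNext == true` before either calls `next`. The outcome is therefore deterministic: one success and one `NoSuchElementException`. The hypothetical `drain` helper shows one caller-side pattern that stays correct with the iterator as written. It treats `next` as the only authoritative call and stops on `NoSuchElementException`.

```scala
import java.util.{Iterator => JIterator}
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical, Spark-free stand-in for the iterator in the diff: the counter
// logic is copied, but `fitOne(index)` replaces `fit(dataset, paramMaps(index))`.
class CounterIterator(numModel: Int, fitOne: Int => String)
    extends JIterator[(Integer, String)] {
  private val counter = new AtomicInteger(0)

  // Only a snapshot: another thread may claim the last index right after this.
  override def hasNext: Boolean = counter.get() < numModel

  override def next(): (Integer, String) = {
    val index = counter.getAndIncrement() // atomically claims a unique index
    if (index < numModel) {
      (Integer.valueOf(index), fitOne(index))
    } else {
      counter.set(numModel)
      throw new NoSuchElementException("Iterator finished.")
    }
  }
}

object HasNextRace {
  // Two threads both observe hasNext == true on the last element, then both
  // call next(). Returns (successful next() calls, NoSuchElementExceptions).
  def raceOnLastElement(): (Int, Int) = {
    val it = new CounterIterator(1, i => s"model-$i")
    val bothChecked = new CountDownLatch(2)
    val successes = new AtomicInteger(0)
    val failures = new AtomicInteger(0)
    val worker = new Runnable {
      def run(): Unit = {
        if (it.hasNext) {
          bothChecked.countDown()
          bothChecked.await() // no next() until both threads passed hasNext
          try {
            it.next()
            successes.incrementAndGet()
          } catch {
            case _: NoSuchElementException => failures.incrementAndGet()
          }
        }
      }
    }
    val threads = Seq.fill(2)(new Thread(worker))
    threads.foreach(_.start())
    threads.foreach(_.join())
    (successes.get(), failures.get())
  }

  // Caller-side pattern that ignores hasNext: each worker calls next() until
  // it throws, so every index is consumed exactly once. Returns sorted indices.
  def drain(numModel: Int, numThreads: Int): Seq[Int] = {
    val it = new CounterIterator(numModel, i => s"model-$i")
    val seen = new ConcurrentLinkedQueue[Integer]()
    val worker = new Runnable {
      def run(): Unit = {
        var done = false
        while (!done) {
          try {
            seen.add(it.next()._1)
          } catch {
            case _: NoSuchElementException => done = true
          }
        }
      }
    }
    val threads = Seq.fill(numThreads)(new Thread(worker))
    threads.foreach(_.start())
    threads.foreach(_.join())
    seen.toArray(new Array[Integer](0)).map(_.intValue).sorted.toSeq
  }
}
```

`hasNext` and `next` are two separate calls, so a counter on its own cannot make `hasNext == true` a guarantee for concurrent callers. That guarantee holds only if callers synchronize externally around the pair, or if they rely on `next` alone as `drain` does.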