Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/20058#discussion_r159023958
--- Diff: python/pyspark/ml/base.py ---
@@ -47,6 +86,28 @@ def _fit(self, dataset):
         """
         raise NotImplementedError()

+    @since("2.3.0")
+    def fitMultiple(self, dataset, params):
+        """
+        Fits a model to the input dataset for each param map in params.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.DataFrame`.
+        :param params: A Sequence of param maps.
+        :return: A thread safe iterable which contains one model for each
+                 param map. Each call to `next(modelIterator)` will return
+                 `(index, model)` where model was fit using `params[index]`.
+                 Param maps may be fit in an order different than their
+                 order in params.
+
+        .. note:: DeveloperApi
+        .. note:: Experimental
+        """
+        estimator = self.copy()
+
+        def fitSingleModel(index):
+            return estimator.fit(dataset, params[index])
+
+        return FitMultipleIterator(fitSingleModel, len(params))
--- End diff --
The idea is you should be able to do something like this:
```
from multiprocessing.pool import ThreadPool

# e.g. a thread pool; a process pool wouldn't work here since the
# lambda below can't be pickled
pool = ThreadPool(processes=4)
modelIter = estimator.fitMultiple(dataset, params)
rng = range(len(params))
for index, model in pool.imap_unordered(lambda _: next(modelIter), rng):
    pass
```
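For context, a minimal sketch of what the `FitMultipleIterator` helper returned by the default implementation could look like (a hedged reconstruction, not necessarily the exact code in this PR): `next()` calls are serialized with a lock so the iterator is thread safe, while the actual `fitSingleModel(index)` call happens outside the lock so models can train concurrently.
```
import threading

class FitMultipleIterator(object):
    """Hands out (index, model) pairs; safe to call next() from many threads."""
    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only the index bookkeeping is locked; the fit itself runs
        # outside the lock so several threads can fit models in parallel.
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration("No models remaining.")
            self.counter += 1
        return index, self.fitSingleModel(index)

    next = __next__  # Python 2 compatibility
```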
That's pretty much how I've set up cross validator to use it:
https://github.com/apache/spark/pull/20058/files/fe3d6bddc3e9e50febf706d7f22007b1e0d58de3#diff-cbc8c36bfdd245e4e4d5bd27f9b95359R292
The reason for setting it up this way is that, when appropriate, Estimators
can implement their own optimized `fitMultiple` methods that just need to
return an "iterator": a class with `__iter__` and `__next__`. For example,
models that use `maxIter` and `maxDepth` params; see the sketch below.