Repository: spark
Updated Branches:
  refs/heads/branch-2.2 d20c64695 -> 00dee3902


[SPARK-20861][ML][PYTHON] Delegate looping over paramMaps to estimators

Changes:

pyspark.ml Estimators can accept either a single dict of params or a list of param maps.
This change allows the CrossValidator and TrainValidationSplit estimators to pass the
full list of param maps through to the underlying estimator, so that the estimator
itself can parallelize fitting when appropriate (e.g., distributed hyperparameter
tuning).
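
To illustrate the contract this relies on, here is a minimal sketch (toy classes, not the actual pyspark.ml code): when `params` is a list of param maps, `fit()` returns one model per map, which is what lets a subclass take over the whole loop.

```python
# Toy sketch of the Estimator.fit() contract the change relies on.
# Names below (MeanEstimator, "scale") are illustrative only.

class Estimator:
    def fit(self, dataset, params=None):
        if isinstance(params, (list, tuple)):
            # Default: fit each param map sequentially. A subclass may
            # override this branch to fit all maps in parallel.
            return [self.fit(dataset, pm) for pm in params]
        elif params is None or isinstance(params, dict):
            return self._fit_one(dataset, params or {})
        else:
            raise TypeError("params must be a dict or a list of dicts")

    def _fit_one(self, dataset, param_map):
        raise NotImplementedError


class MeanEstimator(Estimator):
    # Toy estimator: the "model" is the dataset mean scaled by a param.
    def _fit_one(self, dataset, param_map):
        scale = param_map.get("scale", 1.0)
        return scale * sum(dataset) / len(dataset)


est = MeanEstimator()
models = est.fit([1.0, 2.0, 3.0], [{"scale": 1.0}, {"scale": 2.0}])
print(models)  # -> [2.0, 4.0]
```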

Testing:

Existing unit tests.

Author: Bago Amirbekian <b...@databricks.com>

Closes #18077 from MrBago/delegate_params.

(cherry picked from commit 9434280cfd1db94dc9d52bb0ace8283e710e3124)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/00dee390
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/00dee390
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/00dee390

Branch: refs/heads/branch-2.2
Commit: 00dee39029119845d3b744ee70c562cf073ee678
Parents: d20c646
Author: Bago Amirbekian <b...@databricks.com>
Authored: Tue May 23 20:56:01 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue May 23 20:56:12 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tuning.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/00dee390/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ffeb445..b648582 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,14 +18,11 @@
 import itertools
 import numpy as np
 
-from pyspark import SparkContext
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -232,8 +229,9 @@ class CrossValidator(Estimator, ValidatorParams):
             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
             validation = df.filter(condition)
             train = df.filter(~condition)
+            models = est.fit(train, epm)
             for j in range(numModels):
-                model = est.fit(train, epm[j])
+                model = models[j]
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric/nFolds
@@ -388,8 +386,9 @@ class TrainValidationSplit(Estimator, ValidatorParams):
         condition = (df[randCol] >= tRatio)
         validation = df.filter(condition)
         train = df.filter(~condition)
+        models = est.fit(train, epm)
         for j in range(numModels):
-            model = est.fit(train, epm[j])
+            model = models[j]
             metric = eva.evaluate(model.transform(validation, epm[j]))
             metrics[j] += metric
         if eva.isLargerBetter():

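The payoff of passing the whole list is that an estimator can override the list branch of `fit()` and train all candidates concurrently. A hypothetical self-contained sketch (toy class and param names, not part of pyspark.ml) using a thread pool:

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical estimator: because the validators above now call
# est.fit(train, epm) once with the full list instead of
# est.fit(train, epm[j]) per index, a subclass like this can
# fit every param map concurrently.

class ParallelEstimator:
    def fit(self, dataset, params=None):
        if isinstance(params, (list, tuple)):
            with ThreadPoolExecutor(max_workers=4) as pool:
                # One model per param map, order preserved.
                return list(pool.map(lambda pm: self._fit_one(dataset, pm), params))
        return self._fit_one(dataset, params or {})

    def _fit_one(self, dataset, param_map):
        # Toy "model": dataset mean scaled by a param.
        scale = param_map.get("scale", 1.0)
        return scale * sum(dataset) / len(dataset)


models = ParallelEstimator().fit([2.0, 4.0], [{"scale": 1.0}, {"scale": 3.0}])
print(models)  # -> [3.0, 9.0]
```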

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
