Repository: spark
Updated Branches:
  refs/heads/master 5003736ad -> 04614820e


[SPARK-21088][ML] CrossValidator, TrainValidationSplit support collecting all models when fitting: Python API

## What changes were proposed in this pull request?

Add a Python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting.

## How was this patch tested?

Unit tests added.

Author: WeichenXu <weichen...@databricks.com>

Closes #19627 from WeichenXu123/expose-model-list-py.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/04614820
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/04614820
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/04614820

Branch: refs/heads/master
Commit: 04614820e103feeae91299dc90dba1dd628fd485
Parents: 5003736
Author: WeichenXu <weichen...@databricks.com>
Authored: Mon Apr 16 11:31:24 2018 -0500
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Apr 16 11:31:24 2018 -0500

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala |  11 ++
 .../spark/ml/tuning/TrainValidationSplit.scala  |  11 ++
 .../pyspark/ml/param/_shared_params_code_gen.py |   5 +
 python/pyspark/ml/param/shared.py               |  24 +++++
 python/pyspark/ml/tests.py                      |  78 ++++++++++++++
 python/pyspark/ml/tuning.py                     | 107 ++++++++++++++-----
 python/pyspark/ml/util.py                       |   4 +
 7 files changed, 211 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index a0b507d..c2826dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] (
     this
   }
 
+  // A Python-friendly auxiliary method
+  private[tuning] def setSubModels(subModels: JList[JList[Model[_]]])
+    : CrossValidatorModel = {
+    _subModels = if (subModels != null) {
+      Some(subModels.asScala.toArray.map(_.asScala.toArray))
+    } else {
+      None
+    }
+    this
+  }
+
   /**
    * @return submodels represented in two dimension array. The index of outer 
array is the
    *         fold index, and the index of inner array corresponds to the 
ordering of

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 88ff0df..8d1b9a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] (
     this
   }
 
+  // A Python-friendly auxiliary method
+  private[tuning] def setSubModels(subModels: JList[Model[_]])
+    : TrainValidationSplitModel = {
+    _subModels = if (subModels != null) {
+      Some(subModels.asScala.toArray)
+    } else {
+      None
+    }
+    this
+  }
+
   /**
    * @return submodels represented in array. The index of array corresponds to 
the ordering of
    *         estimatorParamMaps

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py 
b/python/pyspark/ml/param/_shared_params_code_gen.py
index db951d8..6e9e0a3 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -157,6 +157,11 @@ if __name__ == "__main__":
          "TypeConverters.toInt"),
         ("parallelism", "the number of threads to use when running parallel 
algorithms (>= 1).",
          "1", "TypeConverters.toInt"),
+        ("collectSubModels", "Param for whether to collect a list of 
sub-models trained during " +
+         "tuning. If set to false, then only the single best sub-model will be 
available after " +
+         "fitting. If set to true, then all sub-models will be available. 
Warning: For large " +
+         "models, collecting all sub-models can cause OOMs on the Spark 
driver.",
+         "False", "TypeConverters.toBoolean"),
         ("loss", "the loss function to be optimized.", None, 
"TypeConverters.toString")]
 
     code = []

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py 
b/python/pyspark/ml/param/shared.py
index 474c387..08408ee 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -655,6 +655,30 @@ class HasParallelism(Params):
         return self.getOrDefault(self.parallelism)
 
 
+class HasCollectSubModels(Params):
+    """
+    Mixin for param collectSubModels: Param for whether to collect a list of 
sub-models trained during tuning. If set to false, then only the single best 
sub-model will be available after fitting. If set to true, then all sub-models 
will be available. Warning: For large models, collecting all sub-models can 
cause OOMs on the Spark driver.
+    """
+
+    collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for 
whether to collect a list of sub-models trained during tuning. If set to false, 
then only the single best sub-model will be available after fitting. If set to 
true, then all sub-models will be available. Warning: For large models, 
collecting all sub-models can cause OOMs on the Spark driver.", 
typeConverter=TypeConverters.toBoolean)
+
+    def __init__(self):
+        super(HasCollectSubModels, self).__init__()
+        self._setDefault(collectSubModels=False)
+
+    def setCollectSubModels(self, value):
+        """
+        Sets the value of :py:attr:`collectSubModels`.
+        """
+        return self._set(collectSubModels=value)
+
+    def getCollectSubModels(self):
+        """
+        Gets the value of collectSubModels or its default value.
+        """
+        return self.getOrDefault(self.collectSubModels)
+
+
 class HasLoss(Params):
     """
     Mixin for param loss: the loss function to be optimized.

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4ce5454..2ec0be6 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1018,6 +1018,50 @@ class CrossValidatorTests(SparkSessionTestCase):
         cvParallelModel = cv.fit(dataset)
         self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
 
+    def test_expose_sub_models(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        numFolds = 3
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, 
evaluator=evaluator,
+                            numFolds=numFolds, collectSubModels=True)
+
+        def checkSubModels(subModels):
+            self.assertEqual(len(subModels), numFolds)
+            for i in range(numFolds):
+                self.assertEqual(len(subModels[i]), len(grid))
+
+        cvModel = cv.fit(dataset)
+        checkSubModels(cvModel.subModels)
+
+        # Test the default value for option "persistSubModel" to be "true"
+        testSubPath = temp_path + "/testCrossValidatorSubModels"
+        savingPathWithSubModels = testSubPath + "cvModel3"
+        cvModel.save(savingPathWithSubModels)
+        cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+        checkSubModels(cvModel3.subModels)
+        cvModel4 = cvModel3.copy()
+        checkSubModels(cvModel4.subModels)
+
+        savingPathWithoutSubModels = testSubPath + "cvModel2"
+        cvModel.write().option("persistSubModels", 
"false").save(savingPathWithoutSubModels)
+        cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+        self.assertEqual(cvModel2.subModels, None)
+
+        for i in range(numFolds):
+            for j in range(len(grid)):
+                self.assertEqual(cvModel.subModels[i][j].uid, 
cvModel3.subModels[i][j].uid)
+
     def test_save_load_nested_estimator(self):
         temp_path = tempfile.mkdtemp()
         dataset = self.spark.createDataFrame(
@@ -1186,6 +1230,40 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvsParallelModel = tvs.fit(dataset)
         self.assertEqual(tvsSerialModel.validationMetrics, 
tvsParallelModel.validationMetrics)
 
+    def test_expose_sub_models(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, 
evaluator=evaluator,
+                                   collectSubModels=True)
+        tvsModel = tvs.fit(dataset)
+        self.assertEqual(len(tvsModel.subModels), len(grid))
+
+        # Test the default value for option "persistSubModel" to be "true"
+        testSubPath = temp_path + "/testTrainValidationSplitSubModels"
+        savingPathWithSubModels = testSubPath + "cvModel3"
+        tvsModel.save(savingPathWithSubModels)
+        tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+        self.assertEqual(len(tvsModel3.subModels), len(grid))
+        tvsModel4 = tvsModel3.copy()
+        self.assertEqual(len(tvsModel4.subModels), len(grid))
+
+        savingPathWithoutSubModels = testSubPath + "cvModel2"
+        tvsModel.write().option("persistSubModels", 
"false").save(savingPathWithoutSubModels)
+        tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+        self.assertEqual(tvsModel2.subModels, None)
+
+        for i in range(len(grid)):
+            self.assertEqual(tvsModel.subModels[i].uid, 
tvsModel3.subModels[i].uid)
+
     def test_save_load_nested_estimator(self):
         # This tests saving and loading the trained model only.
         # Save/load for TrainValidationSplit will be added later: SPARK-13786

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 545e24c..0c8029f 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.common import _py2java
 from pyspark.ml.param import Params, Param, TypeConverters
-from pyspark.ml.param.shared import HasParallelism, HasSeed
+from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, 
HasSeed
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
@@ -33,7 +33,7 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 
'CrossValidatorModel', 'TrainVa
            'TrainValidationSplitModel']
 
 
-def _parallelFitTasks(est, train, eva, validation, epm):
+def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
     """
     Creates a list of callables which can be called from different threads to 
fit and evaluate
     an estimator in parallel. Each callable returns an `(index, metric)` pair.
@@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm):
     :param eva: Evaluator, used to compute `metric`
     :param validation: DataFrame, validation data set, used for evaluation.
     :param epm: Sequence of ParamMap, params maps to be used during fitting & 
evaluation.
-    :return: (int, float), an index into `epm` and the associated metric value.
+    :param collectSubModel: Whether to collect sub model.
+    :return: (int, float, subModel), an index into `epm` and the associated 
metric value.
     """
     modelIter = est.fitMultiple(train, epm)
 
     def singleTask():
         index, model = next(modelIter)
         metric = eva.evaluate(model.transform(validation, epm[index]))
-        return index, metric
+        return index, metric, model if collectSubModel else None
 
     return [singleTask] * len(epm)
 
@@ -194,7 +195,8 @@ class ValidatorParams(HasSeed):
         return java_estimator, java_epms, java_evaluator
 
 
-class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, 
MLWritable):
+class CrossValidator(Estimator, ValidatorParams, HasParallelism, 
HasCollectSubModels,
+                     MLReadable, MLWritable):
     """
 
     K-fold cross validation performs model selection by splitting the dataset 
into a set of
@@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
 
     @keyword_only
     def __init__(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, numFolds=3,
-                 seed=None, parallelism=1):
+                 seed=None, parallelism=1, collectSubModels=False):
         """
         __init__(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, numFolds=3,\
-                 seed=None, parallelism=1)
+                 seed=None, parallelism=1, collectSubModels=False)
         """
         super(CrossValidator, self).__init__()
         self._setDefault(numFolds=3, parallelism=1)
@@ -246,10 +248,10 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
     @keyword_only
     @since("1.4.0")
     def setParams(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, numFolds=3,
-                  seed=None, parallelism=1):
+                  seed=None, parallelism=1, collectSubModels=False):
         """
         setParams(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, numFolds=3,\
-                  seed=None, parallelism=1):
+                  seed=None, parallelism=1, collectSubModels=False):
         Sets params for cross validator.
         """
         kwargs = self._input_kwargs
@@ -282,6 +284,10 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
         metrics = [0.0] * numModels
 
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
+        subModels = None
+        collectSubModelsParam = self.getCollectSubModels()
+        if collectSubModelsParam:
+            subModels = [[None for j in range(numModels)] for i in 
range(nFolds)]
 
         for i in range(nFolds):
             validateLB = i * h
@@ -290,9 +296,12 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
             validation = df.filter(condition).cache()
             train = df.filter(~condition).cache()
 
-            tasks = _parallelFitTasks(est, train, eva, validation, epm)
-            for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+            tasks = _parallelFitTasks(est, train, eva, validation, epm, 
collectSubModelsParam)
+            for j, metric, subModel in pool.imap_unordered(lambda f: f(), 
tasks):
                 metrics[j] += (metric / nFolds)
+                if collectSubModelsParam:
+                    subModels[i][j] = subModel
+
             validation.unpersist()
             train.unpersist()
 
@@ -301,7 +310,7 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
         else:
             bestIndex = np.argmin(metrics)
         bestModel = est.fit(dataset, epm[bestIndex])
-        return self._copyValues(CrossValidatorModel(bestModel, metrics))
+        return self._copyValues(CrossValidatorModel(bestModel, metrics, 
subModels))
 
     @since("1.4.0")
     def copy(self, extra=None):
@@ -345,9 +354,11 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
         numFolds = java_stage.getNumFolds()
         seed = java_stage.getSeed()
         parallelism = java_stage.getParallelism()
+        collectSubModels = java_stage.getCollectSubModels()
         # Create a new instance of this stage.
         py_stage = cls(estimator=estimator, estimatorParamMaps=epms, 
evaluator=evaluator,
-                       numFolds=numFolds, seed=seed, parallelism=parallelism)
+                       numFolds=numFolds, seed=seed, parallelism=parallelism,
+                       collectSubModels=collectSubModels)
         py_stage._resetUid(java_stage.uid())
         return py_stage
 
@@ -367,6 +378,7 @@ class CrossValidator(Estimator, ValidatorParams, 
HasParallelism, MLReadable, MLW
         _java_obj.setSeed(self.getSeed())
         _java_obj.setNumFolds(self.getNumFolds())
         _java_obj.setParallelism(self.getParallelism())
+        _java_obj.setCollectSubModels(self.getCollectSubModels())
 
         return _java_obj
 
@@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, 
MLReadable, MLWritable):
     .. versionadded:: 1.4.0
     """
 
-    def __init__(self, bestModel, avgMetrics=[]):
+    def __init__(self, bestModel, avgMetrics=[], subModels=None):
         super(CrossValidatorModel, self).__init__()
         #: best model from cross validation
         self.bestModel = bestModel
         #: Average cross-validation metrics for each paramMap in
         #: CrossValidator.estimatorParamMaps, in the corresponding order.
         self.avgMetrics = avgMetrics
+        #: sub model list from cross validation
+        self.subModels = subModels
 
     def _transform(self, dataset):
         return self.bestModel.transform(dataset)
@@ -399,6 +413,7 @@ class CrossValidatorModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         and some extra params. This copies the underlying bestModel,
         creates a deep copy of the embedded paramMap, and
         copies the embedded and extra parameters over.
+        It does not copy the extra Params into the subModels.
 
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
@@ -407,7 +422,8 @@ class CrossValidatorModel(Model, ValidatorParams, 
MLReadable, MLWritable):
             extra = dict()
         bestModel = self.bestModel.copy(extra)
         avgMetrics = self.avgMetrics
-        return CrossValidatorModel(bestModel, avgMetrics)
+        subModels = self.subModels
+        return CrossValidatorModel(bestModel, avgMetrics, subModels)
 
     @since("2.3.0")
     def write(self):
@@ -426,13 +442,17 @@ class CrossValidatorModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         Given a Java CrossValidatorModel, create and return a Python wrapper 
of it.
         Used for ML persistence.
         """
-
         bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = super(CrossValidatorModel, 
cls)._from_java_impl(java_stage)
 
         py_stage = cls(bestModel=bestModel).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
+        if java_stage.hasSubModels():
+            py_stage.subModels = [[JavaParams._from_java(sub_model)
+                                   for sub_model in fold_sub_models]
+                                  for fold_sub_models in 
java_stage.subModels()]
+
         py_stage._resetUid(java_stage.uid())
         return py_stage
 
@@ -454,10 +474,16 @@ class CrossValidatorModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         _java_obj.set("evaluator", evaluator)
         _java_obj.set("estimator", estimator)
         _java_obj.set("estimatorParamMaps", epms)
+
+        if self.subModels is not None:
+            java_sub_models = [[sub_model._to_java() for sub_model in 
fold_sub_models]
+                               for fold_sub_models in self.subModels]
+            _java_obj.setSubModels(java_sub_models)
         return _java_obj
 
 
-class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, 
MLReadable, MLWritable):
+class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, 
HasCollectSubModels,
+                           MLReadable, MLWritable):
     """
     .. note:: Experimental
 
@@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
 
     @keyword_only
     def __init__(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, trainRatio=0.75,
-                 parallelism=1, seed=None):
+                 parallelism=1, collectSubModels=False, seed=None):
         """
         __init__(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, trainRatio=0.75,\
-                 parallelism=1, seed=None)
+                 parallelism=1, collectSubModels=False, seed=None)
         """
         super(TrainValidationSplit, self).__init__()
         self._setDefault(trainRatio=0.75, parallelism=1)
@@ -505,10 +531,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
     @since("2.0.0")
     @keyword_only
     def setParams(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, trainRatio=0.75,
-                  parallelism=1, seed=None):
+                  parallelism=1, collectSubModels=False, seed=None):
         """
         setParams(self, estimator=None, estimatorParamMaps=None, 
evaluator=None, trainRatio=0.75,\
-                  parallelism=1, seed=None):
+                  parallelism=1, collectSubModels=False, seed=None):
         Sets params for the train validation split.
         """
         kwargs = self._input_kwargs
@@ -541,11 +567,19 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        tasks = _parallelFitTasks(est, train, eva, validation, epm)
+        subModels = None
+        collectSubModelsParam = self.getCollectSubModels()
+        if collectSubModelsParam:
+            subModels = [None for i in range(numModels)]
+
+        tasks = _parallelFitTasks(est, train, eva, validation, epm, 
collectSubModelsParam)
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
         metrics = [None] * numModels
-        for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+        for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
             metrics[j] = metric
+            if collectSubModelsParam:
+                subModels[j] = subModel
+
         train.unpersist()
         validation.unpersist()
 
@@ -554,7 +588,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
         else:
             bestIndex = np.argmin(metrics)
         bestModel = est.fit(dataset, epm[bestIndex])
-        return self._copyValues(TrainValidationSplitModel(bestModel, metrics))
+        return self._copyValues(TrainValidationSplitModel(bestModel, metrics, 
subModels))
 
     @since("2.0.0")
     def copy(self, extra=None):
@@ -598,9 +632,11 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
         trainRatio = java_stage.getTrainRatio()
         seed = java_stage.getSeed()
         parallelism = java_stage.getParallelism()
+        collectSubModels = java_stage.getCollectSubModels()
         # Create a new instance of this stage.
         py_stage = cls(estimator=estimator, estimatorParamMaps=epms, 
evaluator=evaluator,
-                       trainRatio=trainRatio, seed=seed, 
parallelism=parallelism)
+                       trainRatio=trainRatio, seed=seed, 
parallelism=parallelism,
+                       collectSubModels=collectSubModels)
         py_stage._resetUid(java_stage.uid())
         return py_stage
 
@@ -620,7 +656,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, 
HasParallelism, MLReadabl
         _java_obj.setTrainRatio(self.getTrainRatio())
         _java_obj.setSeed(self.getSeed())
         _java_obj.setParallelism(self.getParallelism())
-
+        _java_obj.setCollectSubModels(self.getCollectSubModels())
         return _java_obj
 
 
@@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, 
MLReadable, MLWritable):
     .. versionadded:: 2.0.0
     """
 
-    def __init__(self, bestModel, validationMetrics=[]):
+    def __init__(self, bestModel, validationMetrics=[], subModels=None):
         super(TrainValidationSplitModel, self).__init__()
-        #: best model from cross validation
+        #: best model from train validation split
         self.bestModel = bestModel
         #: evaluated validation metrics
         self.validationMetrics = validationMetrics
+        #: sub models from train validation split
+        self.subModels = subModels
 
     def _transform(self, dataset):
         return self.bestModel.transform(dataset)
@@ -651,6 +689,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         creates a deep copy of the embedded paramMap, and
         copies the embedded and extra parameters over.
         And, this creates a shallow copy of the validationMetrics.
+        It does not copy the extra Params into the subModels.
 
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
@@ -659,7 +698,8 @@ class TrainValidationSplitModel(Model, ValidatorParams, 
MLReadable, MLWritable):
             extra = dict()
         bestModel = self.bestModel.copy(extra)
         validationMetrics = list(self.validationMetrics)
-        return TrainValidationSplitModel(bestModel, validationMetrics)
+        subModels = self.subModels
+        return TrainValidationSplitModel(bestModel, validationMetrics, 
subModels)
 
     @since("2.3.0")
     def write(self):
@@ -687,6 +727,10 @@ class TrainValidationSplitModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         py_stage = cls(bestModel=bestModel).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
+        if java_stage.hasSubModels():
+            py_stage.subModels = [JavaParams._from_java(sub_model)
+                                  for sub_model in java_stage.subModels()]
+
         py_stage._resetUid(java_stage.uid())
         return py_stage
 
@@ -708,6 +752,11 @@ class TrainValidationSplitModel(Model, ValidatorParams, 
MLReadable, MLWritable):
         _java_obj.set("evaluator", evaluator)
         _java_obj.set("estimator", estimator)
         _java_obj.set("estimatorParamMaps", epms)
+
+        if self.subModels is not None:
+            java_sub_models = [sub_model._to_java() for sub_model in 
self.subModels]
+            _java_obj.setSubModels(java_sub_models)
+
         return _java_obj
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index c3c47bd..a486c6a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -169,6 +169,10 @@ class JavaMLWriter(MLWriter):
         self._jwrite.overwrite()
         return self
 
+    def option(self, key, value):
+        self._jwrite.option(key, value)
+        return self
+
     def context(self, sqlContext):
         """
         Sets the SQL context to use for saving.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to