Repository: spark Updated Branches: refs/heads/master 5003736ad -> 04614820e
[SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API ## What changes were proposed in this pull request? Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting. ## How was this patch tested? UT added. Author: WeichenXu <weichen...@databricks.com> Closes #19627 from WeichenXu123/expose-model-list-py. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/04614820 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/04614820 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/04614820 Branch: refs/heads/master Commit: 04614820e103feeae91299dc90dba1dd628fd485 Parents: 5003736 Author: WeichenXu <weichen...@databricks.com> Authored: Mon Apr 16 11:31:24 2018 -0500 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Mon Apr 16 11:31:24 2018 -0500 ---------------------------------------------------------------------- .../apache/spark/ml/tuning/CrossValidator.scala | 11 ++ .../spark/ml/tuning/TrainValidationSplit.scala | 11 ++ .../pyspark/ml/param/_shared_params_code_gen.py | 5 + python/pyspark/ml/param/shared.py | 24 +++++ python/pyspark/ml/tests.py | 78 ++++++++++++++ python/pyspark/ml/tuning.py | 107 ++++++++++++++----- python/pyspark/ml/util.py | 4 + 7 files changed, 211 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a0b507d..c2826dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } + /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 88ff0df..8d1b9a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/_shared_params_code_gen.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index db951d8..6e9e0a3 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -157,6 +157,11 @@ if __name__ == "__main__": "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/shared.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 474c387..08408ee 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -655,6 +655,30 @@ class HasParallelism(Params): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4ce5454..2ec0be6 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1018,6 +1018,50 @@ class CrossValidatorTests(SparkSessionTestCase): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1186,6 +1230,40 @@ class TrainValidationSplitTests(SparkSessionTestCase): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tuning.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 545e24c..0c8029f 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -33,7 +33,7 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainVa 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -194,7 +195,8 @@ class ValidatorParams(HasSeed): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -246,10 +248,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -282,6 +284,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -290,9 +296,12 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -301,7 +310,7 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -345,9 +354,11 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -367,6 +378,7 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -399,6 +413,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -407,7 +422,8 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -426,13 +442,17 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -454,10 +474,16 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -505,10 +531,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -541,11 +567,19 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -554,7 +588,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -598,9 +632,11 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -620,7 +656,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -651,6 +689,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -659,7 +698,8 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -687,6 +727,10 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -708,6 +752,11 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd..a486c6a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ class JavaMLWriter(MLWriter): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org